
Automated Discovery Tools to Study Behavioral Competencies of Biological Networks [📝 writing in progress]

Authors: Mayalen Etcheverry (INRIA, Flowers team, Poietis); Clément Moulin-Frier (INRIA, Flowers team); Pierre-Yves Oudeyer (INRIA, Flowers team); Michael Levin (The Levin Lab, Tufts University)
Published: May 1, 2023

Introduction

TL;DR

This tutorial accompanies our paper Automated Discovery Tools to Study Behavioral Competencies of Biological Networks.
It is intended to walk you through the set of tools we use to (i) automatically explore the space of input stimuli of a given biological network in order to build a behavioral catalogue of that system, and (ii) analyze the robustness of the discovered abilities in order to infer the network's navigation competencies in the transcriptional space of gene activation.

📝 How to follow this tutorial

💻 AutoDiscJax

Throughout this tutorial, we will be using the AutoDiscJax library, a library built on top of jax and equinox to facilitate automated experimentation and simulation of biological network pathways.

AutoDiscJax follows two main design principles: 1) Everything is a module, where a module is simply a parametrized function that takes inputs and returns outputs (and log_data). All AutoDiscJax modules (adx.Module) are implemented as equinox modules (eqx.Module), which represents the function as a callable PyTree (and hence makes it compatible with JAX transformations) while keeping an intuitive API for model building (a Python class with a __call__ method). The only addition with respect to equinox is that, when instantiating an adx.Module, the user must specify the PyTree structure, shape and dtype of the module's outputs. 2) An experiment pipeline defines (i) how modules interact sequentially and exchange information, and (ii) what information should be collected and saved in the experiment history.
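To make the first principle concrete, here is a plain-Python sketch of the pattern: a parametrized callable that receives a random key and some inputs, returns outputs together with log data, and checks its outputs against a shape declared at construction time. `ToyModule` is hypothetical and is not the adx.Module API; it leaves out the equinox/PyTree machinery that makes real modules compatible with JAX transformations, and an integer seed stands in for a JAX PRNG key.

```python
import random
from dataclasses import dataclass


@dataclass
class ToyModule:
    """Hypothetical illustration of 'everything is a module' (not adx.Module)."""
    scale: float      # the module's parameter
    out_shape: tuple  # output shape declared up front, checked on every call

    def __call__(self, key, inputs):
        rng = random.Random(key)  # the key seeds any randomness inside the module
        outputs = [self.scale * x + rng.gauss(0.0, 0.01) for x in inputs]
        if (len(outputs),) != self.out_shape:
            raise ValueError(f"expected shape {self.out_shape}, got {(len(outputs),)}")
        log_data = {"n_inputs": len(inputs)}
        return outputs, log_data


module = ToyModule(scale=2.0, out_shape=(3,))
outputs, log_data = module(key=0, inputs=[1.0, 2.0, 3.0])
```

Declaring the output structure up front is what lets a pipeline allocate and stack results (e.g. across a batch) before any module has run.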

AutoDiscJax provides a handful of already-implemented modules and pipelines to 1) simulate biological networks while intervening on them according to our needs (see Part 1), 2) automatically organize experimentation in those systems, by implementing a variety of exploration approaches such as random, optimization-driven and curiosity-driven search (see Part 2 and Part 3), and 3) analyze the discoveries of the exploration method, for instance by testing their robustness to various perturbations (see Part 4).

Finally, AutoDiscJax takes advantage of JAX's main features (just-in-time compilation, automatic vectorization and automatic differentiation), which are especially useful for parallel experimentation, computational speedups and gradient-based optimization.

Part 1: Numerical simulation (with interventions) of a GRN model

Ordinary differential equation (ODE) models are widely used to represent the behavior of complex biological processes. These ODE models are often experimentally determined and curated by biologists, e.g. by combining experimental and data-driven methods to determine whether and how pairs of biomolecules interact. The resulting mathematical models are usually stored and exchanged in the Systems Biology Markup Language (SBML). Thanks to community efforts, large collections of published ODE models have been made publicly available in online databases such as the BioModels website.

For this tutorial, we will study the gene regulatory network (GRN) model of Cho et al. that describes the influence of RKIP on the ERK signaling pathway. It is hosted on the BioModels database under the identifier BIOMD0000000647, and its reaction graph is shown below:
[Figure: reaction graph of the BIOMD0000000647 model]

The SBML file describes the species, parameters and reactions involved in this model, but we need our own tools to carry out simulation studies.

In the first part of this tutorial, we will show how to use the SBMLtoODEjax library to automatically retrieve the SBML file and convert it into a Python file, and the AutoDiscJax library to easily simulate the model and manipulate it according to our needs.

biomodel_idx = 647
observed_node_names = ["ERK", "RKIPP_RP"]

SBML to python conversion

Let's first use the SBMLtoODEjax library to download the SBML file and generate the corresponding python class that will allow us to simulate the model's dynamics.

import os

import sbmltoodejax.biomodels_api
import sbmltoodejax.modulegeneration
import sbmltoodejax.parse

out_model_sbml_filepath = f"data/biomodel_{biomodel_idx}.xml"
out_model_jax_filepath = f"data/biomodel_{biomodel_idx}.py"
os.makedirs("data", exist_ok=True)  # make sure the output folder exists

# Download the SBML file from BioModels (skipped if it already exists)
if not os.path.exists(out_model_sbml_filepath):
    model_xml_body = sbmltoodejax.biomodels_api.get_content_for_model(biomodel_idx)
    with open(out_model_sbml_filepath, 'w') as f:
        f.write(model_xml_body)

# Generate the python class from the SBML file (skipped if it already exists)
if not os.path.exists(out_model_jax_filepath):
    model_data = sbmltoodejax.parse.ParseSBMLFile(out_model_sbml_filepath)
    sbmltoodejax.modulegeneration.GenerateModel(model_data, out_model_jax_filepath)

At this point, you should have created a biomodel_647.py file that contains several variables and functional modules as follows:

System Rollout module

Ok, let's now create our first module: the System Rollout. In short, the system rollout can be seen as a wrapper around the previously created sbmltoodejax ModelRollout module that additionally lets us apply all sorts of interventions and/or perturbations during the rollout simulation, without having to modify the rollout codebase. We'll see how to do that shortly, but first let's create our system rollout module and see what a simulation looks like.

When instantiating the ModelStep module, we should specify the desired time step $\Delta t$ and the ODE solver parameters $(atol, rtol, mxstep)$.

# System Rollout Config
system_rollout_config = Dict()
system_rollout_config.system_type = "grn"
system_rollout_config.model_filepath = out_model_jax_filepath # path of model class that we just created using sbmltoodejax
system_rollout_config.atol = 1e-6 # parameters for the ODE solver 
system_rollout_config.rtol = 1e-12
system_rollout_config.mxstep = 1000
system_rollout_config.deltaT = 0.1 # the ODE solver will compute values every 0.1 second
system_rollout_config.n_secs = 1000 # number of seconds of one rollout in the system
system_rollout_config.n_system_steps = int(system_rollout_config.n_secs/system_rollout_config.deltaT) # total number of steps returned after a rollout

# Create the module
system_rollout = create_system_rollout_module(system_rollout_config)

# Get observed node ids
observed_node_ids = [system_rollout.grn_step.y_indexes[observed_node_names[0]], system_rollout.grn_step.y_indexes[observed_node_names[1]]]
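The config derives n_system_steps = n_secs / deltaT = 1000 / 0.1 = 10000, which matches the last dimension of the trajectories printed later in batch mode, (11, 10000). The snippet below checks that arithmetic and adds a caveat for your own configs: because int() truncates, a float division that lands just below an integer silently loses a step, so round() is the safer conversion (a suggestion, not something the tutorial requires).

```python
deltaT = 0.1   # seconds between returned time points
n_secs = 1000  # simulated duration in seconds

# Number of returned steps, computed as in the config above
n_system_steps = int(n_secs / deltaT)

# int() truncates: 0.3 / 0.1 evaluates to 2.9999999999999996 in float64
truncated = int(0.3 / 0.1)
rounded = round(0.3 / 0.1)
```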

Let's now simulate the module. By default, the rollout starts with the initial concentrations as given in the SBML file.

key, subkey = jrandom.split(key)
default_system_outputs, log_data = system_rollout(subkey) # all autodiscjax modules take a JAX random key as input (needed e.g. for random sampling operations within the module)

Here is what we obtain when simulating the network with the default initial conditions (see the original Figure 5 of Cho et al.'s paper):

Figure 1: Simulation results of the mathematical modeling for default initial condition.

We can also observe the resulting trajectories in phase space, which we call the transcriptional space for gene regulatory networks. Let's say, for instance, that we are interested in observing the activation response of two specific nodes: ERK and RKIPP_RP. By plotting their trajectory in the transcriptional space, we can see that they start at point A(0,0) and navigate until they reach a steady state at point B(0.036,0.055).

Figure 2: Simulation results for default initial condition. Trajectory of nodes ERK and RKIPP_RP is shown in transcriptional space.

Applying interventions

We would now like to apply several interventions during the system rollout to see how they influence the trajectory of the network in transcriptional space. The system rollout module lets us do this by specifying an intervention function $(y^-, y, w^-, w, c^-, c, t^-, t) \mapsto (y, w, c, t)$ that is called at every step of a rollout, where $(y^-,w^-,c^-,t^-)$ are the variable values before the step (t-1) and $(y,w,c,t)$ are the values that the ModelStep function would return without intervention (and that the intervention function may overwrite).
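A plain-Python sketch of where such a hook sits in a rollout loop. Everything here is hypothetical: a one-species decay model $dy/dt = -c\,y$, a forward-Euler `step`, and a hook signature reduced to (y, c, t) because the toy has no w variables. The real system rollout relies on the ODE solver generated by sbmltoodejax and on JAX control flow.

```python
def step(y, c, t, dt):
    """Hypothetical model step: forward-Euler update of dy/dt = -c * y."""
    return y - dt * c * y, c, t + dt


def rollout(y0, c0, n_steps, dt, intervention_fn=None):
    """Toy rollout calling an intervention hook after every model step.

    The hook sees the values before the step (y_prev, c_prev, t_prev) and the
    values proposed by the step (y, c, t), and returns the values to keep.
    """
    y, c, t = y0, c0, 0.0
    ys = [y]
    for _ in range(n_steps):
        y_prev, c_prev, t_prev = y, c, t
        y, c, t = step(y_prev, c_prev, t_prev, dt)
        if intervention_fn is not None:
            y, c, t = intervention_fn(y_prev, y, c_prev, c, t_prev, t)
        ys.append(y)
    return ys


def clamp_after_5s(y_prev, y, c_prev, c, t_prev, t):
    """Example intervention: from t = 5 s on, overwrite y with 1.0."""
    return (1.0 if t >= 5.0 else y), c, t


free = rollout(y0=2.0, c0=0.5, n_steps=100, dt=0.1)
clamped = rollout(y0=2.0, c0=0.5, n_steps=100, dt=0.1, intervention_fn=clamp_after_5s)
```

Because the hook only overwrites what the step returns, the model code itself never changes; that is the design choice that lets interventions be swapped without touching the rollout.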

Several intervention functions are already provided in AutoDiscJax. In particular, the PiecewiseIntervention module controls the variables of the network with a piecewise-defined function, which makes it possible to apply a broad range of interventions depending on one's needs and constraints. The module configuration specifies which variable(s) (stored in y, w or c) to intervene on, when (on which time intervals), and how (e.g. by setting the variable value with PiecewiseSetConstantIntervention, or by adding to the variable's current value with PiecewiseAddConstantIntervention).
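A minimal sketch of the idea behind PiecewiseSetConstantIntervention and PiecewiseAddConstantIntervention, written as plain Python functions. The helper names and the closed-interval convention are assumptions, not the AutoDiscJax implementation. It also shows why an interval of [0, deltaT/2], as used for Intervention 1, only covers the t = 0 grid point and therefore acts as a change of initial conditions.

```python
def time_to_interval(t, intervals):
    """Index of the first interval [start, end] containing t, or -1 if none does."""
    for idx, (start, end) in enumerate(intervals):
        if start <= t <= end:
            return idx
    return -1


def piecewise_set_constant(y, t, intervals, values):
    """Overwrite y with values[idx] while t lies in interval idx."""
    idx = time_to_interval(t, intervals)
    return y if idx == -1 else values[idx]


def piecewise_add_constant(y, t, intervals, values):
    """Add values[idx] to y while t lies in interval idx."""
    idx = time_to_interval(t, intervals)
    return y if idx == -1 else y + values[idx]


# Two clamping windows with one value per window (the setting of Intervention 2)
intervals = [[0, 10], [400, 410]]
values = [2.5, 1.0]
samples = [piecewise_set_constant(0.3, t, intervals, values) for t in (5.0, 200.0, 405.0)]

# A window of half a time step only contains the first grid point
deltaT = 0.1
grid = [i * deltaT for i in range(5)]
hits = [t for t in grid if time_to_interval(t, [[0, deltaT / 2.0]]) != -1]
```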

Let's try it and see how our ERK-pathway network reacts to the following interventions: 1) start the trajectory in ERK-RKIPP_RP space from another point A'(0.03, 0.02); 2) clamp node MEKPP to 2.5 during the first 10 seconds, and then to 1.0 during 10 additional seconds starting at t=400; 3) change the value of the kinetic parameter k5 from 0.0315 to 0.1.

Intervention 1: changing the initial species amount

# Create the intervention
controlled_intervals = [[0, system_rollout_config.deltaT/2.0]]
controlled_node_names = ["ERK", "RKIPP_RP"]
controlled_node_values = [[0.03], [0.02]]

intervention_fn_1 = grn.PiecewiseSetConstantIntervention(time_to_interval_fn=grn.TimeToInterval(intervals=controlled_intervals))
intervention_params_1 = DictTree() # A DictTree is a Dict container (i.e. a dictionary whose items can be accessed and set like attributes) that is registered as a JAX PyTree
for (node_name, node_value) in zip(controlled_node_names, controlled_node_values):
    node_idx = system_rollout.grn_step.y_indexes[node_name]
    intervention_params_1.y[node_idx] = jnp.array(node_value)

# Run the system with the intervention
key, subkey = jrandom.split(key)
system_outputs, log_data = system_rollout(subkey, intervention_fn=intervention_fn_1, intervention_params=intervention_params_1)
Figure 3: Simulation results for modified initial condition A'. Trajectory is shown for nodes ERK and RKIPP_RP in transcriptional space.

👉 Interestingly, we can see that despite forcing the network to start at another point A', it still converges to the same point B. Moreover, instead of going directly from A' to B, the network seems to take a "detour" to rejoin its initial trajectory, and then follows that trajectory until point B.

Intervention 2: clamping the species amount to specific values

# Create the intervention
controlled_intervals = [[0, 10], [400, 410]]
controlled_node_names = ["MEKPP"]
controlled_node_values = [[2.5, 1.0]]

intervention_fn_2 = grn.PiecewiseSetConstantIntervention(time_to_interval_fn=grn.TimeToInterval(intervals=controlled_intervals))
intervention_params_2 = DictTree() # A DictTree is a Dict container (i.e. a dictionary whose items can be accessed and set like attributes) that is registered as a JAX PyTree
for (node_name, node_value) in zip(controlled_node_names, controlled_node_values):
    node_idx = system_rollout.grn_step.y_indexes[node_name]
    intervention_params_2.y[node_idx] = jnp.array(node_value)

# Run the system with the intervention
key, subkey = jrandom.split(key)
system_outputs, log_data = system_rollout(subkey, intervention_fn=intervention_fn_2, intervention_params=intervention_params_2)
Figure 4: Simulation results for default initial condition A, with clamping of node MEKPP. (left) Evolution of MEKPP (with clamp interventions) through reaction time. (right) Trajectory of (ERK, RKIPP_RP) in transcriptional space.

👉 Here, clamping MEKPP seems to have some effect on ERK but not on RKIPP_RP. After the initial clamping of MEKPP to 2.5 (during 10 seconds), the trajectory of the ERK-RKIPP_RP pair still follows a very similar S-shaped curve and arrives close to the original point B (at t=400), but with a slightly lower ERK expression level. When we re-clamp MEKPP to a lower activation (1.0 for 10 seconds at t=400), ERK expression is affected and the final steady state B' is shifted to the right.

Intervention 3: changing the kinetic parameters

# Create the intervention
controlled_intervals = [[0, system_rollout_config.deltaT/2.0]]
controlled_param_names = ["k5"]
controlled_param_values = [[0.1]]

intervention_fn_3 = grn.PiecewiseSetConstantIntervention(time_to_interval_fn=grn.TimeToInterval(intervals=controlled_intervals))
intervention_params_3 = DictTree() 
for (param_name, param_value) in zip(controlled_param_names, controlled_param_values):
    param_idx = system_rollout.grn_step.c_indexes[param_name] #this time we specify intervention parameter value for the key "c"
    intervention_params_3.c[param_idx] = jnp.array(param_value) 
    print(f"Initial {param_name} param value: {system_rollout.c[param_idx]}, changed to {param_value}")


# Run the system with the intervention
key, subkey = jrandom.split(key)
system_outputs, log_data = system_rollout(subkey, intervention_fn=intervention_fn_3, intervention_params=intervention_params_3)
Initial k5 param value: 0.03150000050663948, changed to [0.1]
Figure 5: Simulation results for default initial condition A and modified kinetic parameter k5. Trajectory of (ERK, RKIPP_RP) is shown in transcriptional space.

👉 Here we can see that changing the parameter k5 shifts the trajectory end point quite significantly from B to B', but that the trajectory qualitatively keeps a similar "S" shape.

Applying perturbations

Similarly, we would like to apply several perturbations during the system rollout to see how they influence the trajectory of the network in transcriptional space. In AutoDiscJax, a perturbation is implemented in the same way as an intervention (by specifying a perturbation function that is also called at every step of a rollout). The difference is only conceptual: a perturbation represents uncontrolled events (such as noise or other environmental stresses), whereas an intervention represents controlled events (i.e. applied by the experimenter, such as drug stimuli).

Several perturbation functions are already provided in AutoDiscJax, and we'll be using them more in depth in Part 4 of this tutorial. For the moment, let's just see how to add dynamical noise to the system rollout.

# Create the perturbation
perturbed_intervals = [[t, t+system_rollout_config.deltaT/2] for t in range(100, 300, 5)]
perturbed_node_names = ["ERK", "RKIPP_RP"]
perturbed_node_values = []
for node_name in perturbed_node_names:
    key, subkey = jrandom.split(key)
    perturbed_node_values.append(list(0.005*jrandom.normal(subkey, shape=(len(range(100, 300, 5)), ))))

perturbation_fn_1 = grn.PiecewiseAddConstantIntervention(time_to_interval_fn=grn.TimeToInterval(intervals=perturbed_intervals))
perturbation_params_1 = DictTree()
for (node_name, node_value) in zip(perturbed_node_names, perturbed_node_values):
    node_idx = system_rollout.grn_step.y_indexes[node_name]
    perturbation_params_1.y[node_idx] = jnp.array(node_value)

# Run the system with the perturbation (no intervention)
key, subkey = jrandom.split(key)
system_outputs, log_data = system_rollout(subkey, intervention_fn=None, intervention_params=None, perturbation_fn=perturbation_fn_1, perturbation_params=perturbation_params_1)
Figure 6: Simulation results for default initial condition A, with noise perturbations. (left) Noise perturbations on ERK and RKIPP_RP through reaction time. (right) Trajectory in transcriptional space.

👉 Here we can see that, once it has converged to point B, the GRN is fairly robust to noise.
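One way to picture this robustness is a toy model: if the dynamics near B behave like a linear relaxation toward B (a simplifying assumption; the real network is nonlinear), each additive kick decays geometrically and the state returns to B. The sketch mirrors the perturbation schedule used above (one Gaussian kick of std 0.005 every 5 s between t=100 and t=300), with a made-up relaxation rate k and a single ERK-like variable.

```python
import random

rng = random.Random(0)

b = 0.036          # toy steady state, close to ERK's value at point B
k, dt = 1.0, 0.1   # hypothetical relaxation rate (1/s) and time step (s)
y = b
trajectory = []
for i in range(4000):                      # 400 s of simulated time
    t = i * dt
    y = y + dt * k * (b - y)               # linear relaxation toward b
    if 100 <= t < 300 and i % 50 == 0:     # one Gaussian kick every 5 s
        y += 0.005 * rng.gauss(0.0, 1.0)
    trajectory.append(y)

max_deviation = max(abs(v - b) for v in trajectory)
final_deviation = abs(trajectory[-1] - b)
```

In this toy, each kick shrinks by a factor (1 - k*dt)^50 before the next one arrives, so deviations never accumulate; a real test of robustness (as in Part 4) has to measure this on the actual network instead.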

Simulations in batch mode

To finish Part 1 of this tutorial, let's see how AutoDiscJax allows us to perform simulations in parallel.

# Put the system in batch mode
batched_system_rollout = vmap(system_rollout, in_axes=(0, None, 0))

# Create the M=10 interventions (vector of starting positions between minval and maxval)
M = 10
controlled_node_names = ["ERK", "RKIPP_RP"]
controlled_node_minvals = [0.02, 0.02]
controlled_node_maxvals = [0.08, 0.08]
batched_interventions_params_1 = DictTree()
for (node_name, node_minval, node_maxval) in zip(controlled_node_names, controlled_node_minvals, controlled_node_maxvals):
    node_idx = system_rollout.grn_step.y_indexes[node_name]
    key, subkey = jrandom.split(key)
    batched_interventions_params_1.y[node_idx] = jrandom.uniform(subkey, shape=(M, 1), minval=node_minval, maxval=node_maxval)

key, *subkeys = jrandom.split(key, num=M + 1)
batched_system_outputs, log_data = batched_system_rollout(jnp.array(subkeys), intervention_fn_1, batched_interventions_params_1)

print(default_system_outputs.ys.shape, batched_system_outputs.ys.shape)
(11, 10000) (10, 11, 10000)

Using the jax vmap transformation, the above code automatically vectorizes the call to the system rollout module over different intervention parameters and stores the stacked results in the batched_system_outputs variable. This is a very convenient (and fast) way to test several interventions on the biological network.

Figure 7: Simulation results for different initial conditions A0, ..., A9, launched in batch mode. Obtained trajectories of (ERK, RKIPP_RP) are shown in transcriptional space.

This plot shows the M=10 resulting trajectories obtained in the transcriptional space when applying intervention 1 for 10 different starting conditions $A0, \dots, A9$.

👉 We can see that despite starting the simulations from 10 different positions (initial amounts of ERK and RKIPP_RP), they all converge to the same steady state point B.
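In `vmap(system_rollout, in_axes=(0, None, 0))`, the random keys and the intervention parameters are mapped over their leading axis, while the intervention function is shared by all M runs. The plain-Python loop below (a hypothetical `loop_vmap`, not part of JAX, with a toy rollout) spells out those semantics. jax.vmap computes the same values for pure functions, but returns stacked arrays produced by a single vectorized call rather than a Python list, which is where the speedup comes from.

```python
def loop_vmap(fn, in_axes):
    """Plain-Python stand-in for the semantics of jax.vmap's in_axes.

    Arguments marked 0 are sliced along their first axis; arguments marked
    None are passed unchanged to every call. This loops instead of vectorizing.
    """
    def batched(*args):
        sizes = {len(a) for a, ax in zip(args, in_axes) if ax == 0}
        if len(sizes) != 1:
            raise ValueError("all mapped arguments must share the same batch size")
        (n,) = sizes
        return [fn(*[a[i] if ax == 0 else a for a, ax in zip(args, in_axes)])
                for i in range(n)]
    return batched


def toy_rollout(key, intervention_fn, params):
    """Hypothetical rollout: returns the key it used and the intervened value."""
    return key, intervention_fn(params)


batched_toy_rollout = loop_vmap(toy_rollout, in_axes=(0, None, 0))
outputs = batched_toy_rollout([0, 1, 2], lambda p: 2 * p, [0.02, 0.05, 0.08])
```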

Part 2: Automated experimentation approaches and challenges

batch_size = 100

random_intervention_generator_config = Dict()
random_intervention_generator_config.intervention_type = "set_uniform"
random_intervention_generator_config.controlled_node_ids = list(range(len(default_system_outputs.ys)))
random_intervention_generator_config.controlled_intervals = [[0, system_rollout.deltaT/2.0]]

random_search_discoveries = {}
for r in [1, 10, 100]:
    controlled_node_minvals = default_system_outputs.ys.min(-1)/r
    controlled_node_maxvals = default_system_outputs.ys.max(-1)*r

    intervention_params_tree = DictTree()
    for y_idx in random_intervention_generator_config.controlled_node_ids:
        intervention_params_tree.y[y_idx] = "placeholder"

    random_intervention_generator_config.out_treedef = jtu.tree_structure(intervention_params_tree)
    random_intervention_generator_config.out_shape = jtu.tree_map(lambda _: (len(random_intervention_generator_config.controlled_intervals),),
                                             intervention_params_tree)
    random_intervention_generator_config.out_dtype = jtu.tree_map(lambda _: jnp.float32, intervention_params_tree)

    random_intervention_generator_config.low = DictTree()
    random_intervention_generator_config.high = DictTree()
    for (node_idx, node_minval, node_maxval) in zip(random_intervention_generator_config.controlled_node_ids, controlled_node_minvals, controlled_node_maxvals):
        random_intervention_generator_config.low.y[node_idx] = jnp.array([node_minval])
        random_intervention_generator_config.high.y[node_idx] = jnp.array([node_maxval])

    random_intervention_generator, intervention_fn = create_intervention_module(random_intervention_generator_config)
    batched_random_intervention_generator = vmap(random_intervention_generator)

    # Generate random interventions (batch mode)
    key, *subkeys = jrandom.split(key, num=batch_size+1)
    batched_interventions_params, log_data = batched_random_intervention_generator(jnp.array(subkeys))

    # Rollout the system (batch mode)
    key, *subkeys = jrandom.split(key, num=batch_size+1)
    batched_system_outputs, log_data = batched_system_rollout(jnp.array(subkeys), intervention_fn, batched_interventions_params)

    # Store the discoveries
    random_search_discoveries[r] = batched_system_outputs
Figure 8: Random Search discoveries in transcriptional space, for different input parameter range r (r=1, r=10, r=100) each with N=100 exploration runs. (ERK, RKIPP_RP) endpoints (t=1000 secs) of the discovered trajectories are shown in transcriptional space.

👉 With too constrained a range (r=1), random search misses a large part of what is feasible (in terms of reachable transcriptional space). Choosing a looser input parameter range (e.g. r=10) reveals that it is possible to find novel steady states (e.g. with ERK > 5). However, this also means that the exploration space is bigger and hence harder to explore. In fact, we can see that for r=10 and r=100 random search is not very efficient, as most of the discoveries are located on the ERK=0 axis (i.e. fall in the ERK=0 valley/attractor).
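Why does a larger r make the space harder to cover? For non-negative amounts, each node's sampling interval [min/r, max*r] is at least r times wider than for r=1, so across the 11 controlled nodes the search volume grows by at least a factor r^11 for the same budget of 100 samples. The plain-Python sketch below reproduces the sampling rule on made-up per-node ranges; `sample_initial_conditions` is a hypothetical helper, not the AutoDiscJax random generator. Note also that a node whose default minimum is 0 keeps a lower bound of 0 for every r.

```python
import random


def sample_initial_conditions(key, traj_min, traj_max, r):
    """Draw one uniform initial amount per node in [min / r, max * r]."""
    rng = random.Random(key)
    return [rng.uniform(lo / r, hi * r) for lo, hi in zip(traj_min, traj_max)]


# Hypothetical per-node min / max of a default trajectory (3 nodes)
traj_min = [0.0, 0.01, 0.2]
traj_max = [0.5, 0.06, 1.0]

samples = {r: sample_initial_conditions(r, traj_min, traj_max, r) for r in (1, 10, 100)}

# Width of each node's sampling interval, per r
widths = {r: [hi * r - lo / r for lo, hi in zip(traj_min, traj_max)] for r in (1, 10, 100)}
```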

# Define loss function (L2 distance to target point)
def evaluate_worker_fn(key, intervention_params, intervention_fn, system_rollout, observed_node_ids, target_point, low, high):

    # rollout the system with parameters
    key, subkey = jrandom.split(key)
    system_outputs, log_data = system_rollout(subkey, intervention_fn, intervention_params)

    # Get trajectory final point
    reached_point = system_outputs.ys[jnp.array(observed_node_ids), -1]

    # L2 distance to the target point, each axis normalized by its range (high - low)
    loss = jnp.sqrt((jnp.square((reached_point - target_point)/(high-low))).sum())

    # Append info to log data
    log_data = DictTree()
    log_data.reached_point = reached_point
    log_data.loss = loss

    return loss, log_data

previously_reached_points = random_search_discoveries[100].ys[:, jnp.array(observed_node_ids), -1]
low = jnp.nanmin(previously_reached_points, axis=0)
high = jnp.nanmax(previously_reached_points, axis=0)

target_point_1 = jnp.array([200., 4.])
evaluate_worker_fn_1 = jtu.Partial(evaluate_worker_fn, intervention_fn=intervention_fn, system_rollout=system_rollout, 
                                   observed_node_ids=observed_node_ids, target_point=target_point_1, low=low, high=high)

target_point_2 = jnp.array([200., 10.])
evaluate_worker_fn_2 = jtu.Partial(evaluate_worker_fn, intervention_fn=intervention_fn, system_rollout=system_rollout,
                                   observed_node_ids=observed_node_ids, target_point=target_point_2, low=low, high=high)
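evaluate_worker_fn divides each coordinate by the range high - low of previously reached points before taking the L2 norm. Without this, ERK (which reaches values in the hundreds) would dominate RKIPP_RP (values of a few units) in the loss. The plain-Python sketch below uses hypothetical ranges, not the ones computed from random search, purely to show the effect.

```python
import math


def normalized_l2(reached, target, low, high):
    """L2 distance after dividing each coordinate by its range (high - low)."""
    return math.sqrt(sum(((r - g) / (h - l)) ** 2
                         for r, g, l, h in zip(reached, target, low, high)))


reached = [0.036, 0.055]              # roughly point B (ERK, RKIPP_RP)
target = [200.0, 4.0]                 # target G1
low, high = [0.0, 0.0], [250.0, 10.0]  # hypothetical ranges

raw = math.dist(reached, target)      # unnormalized: almost entirely the ERK gap
scaled = normalized_l2(reached, target, low, high)
```

With the normalization, the RKIPP_RP gap (about 3.9 out of a range of 10) contributes a sizeable share of the loss instead of being negligible next to the ERK gap.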


# Create SGD Optimizer
intervention_optimizer_config = Dict()
intervention_optimizer_config.n_optim_steps = 100
intervention_optimizer_config.n_workers = 1
intervention_optimizer_config.init_noise_std = jtu.tree_map(lambda node: 0.0, random_intervention_generator.low)
intervention_optimizer_config.lr = jtu.tree_map(lambda low, high: 0.01*(high-low), 
                                                random_intervention_generator.low, random_intervention_generator.high)

optimizer = optimizers.SGDOptimizer(random_intervention_generator.out_treedef,
                                    random_intervention_generator.out_shape,
                                    random_intervention_generator.out_dtype,
                                    random_intervention_generator.low,
                                    random_intervention_generator.high,
                                    intervention_optimizer_config.n_optim_steps,
                                    intervention_optimizer_config.n_workers,
                                    intervention_optimizer_config.init_noise_std,
                                    intervention_optimizer_config.lr
                                )
# Start position 1
start_intervention_params_1 = DictTree()
for node_idx in random_intervention_generator_config.controlled_node_ids:
    start_intervention_params_1.y[node_idx] = default_system_outputs.ys[node_idx, 0][jnp.newaxis]

# Start position 2
selected_intervention_ids, distances = nearest_neighbors(target_point_1/(high-low), previously_reached_points/(high-low), k=1)
start_intervention_params_2 = DictTree()
for node_idx in random_intervention_generator_config.controlled_node_ids:
    start_intervention_params_2.y[node_idx] = random_search_discoveries[100].ys[selected_intervention_ids[0], node_idx, 0][jnp.newaxis]
# Optimization Run 1
print(f"Start position B1 {default_system_outputs.ys[jnp.array(observed_node_ids), -1]} toward target position G1 {target_point_1}")
key, subkey = jrandom.split(key)
optimized_intervention_params_1, log_data_1 = optimizer(subkey, start_intervention_params_1, evaluate_worker_fn_1)

# Optimization Run 2
print(f"Start position B2 {previously_reached_points[selected_intervention_ids][0]} toward target position G1 {target_point_1}")
key, subkey = jrandom.split(key)
optimized_intervention_params_2, log_data_2 = optimizer(subkey, start_intervention_params_2, evaluate_worker_fn_1)

# Optimization Run 3
print(f"Start position B1 {default_system_outputs.ys[jnp.array(observed_node_ids), -1]} toward target position G2 {target_point_2}")
key, subkey = jrandom.split(key)
optimized_intervention_params_3, log_data_3 = optimizer(subkey, start_intervention_params_1, evaluate_worker_fn_2)

# Optimization Run 4
print(f"Start position B2 {previously_reached_points[selected_intervention_ids][0]} toward target position G2 {target_point_2}")
key, subkey = jrandom.split(key)
optimized_intervention_params_4, log_data_4 = optimizer(subkey, start_intervention_params_2, evaluate_worker_fn_2)
Start position B1 [0.03638334 0.05518651] toward target position G1 [200.   4.]
Start position B2 [163.21994   4.17591] toward target position G1 [200.   4.]
Start position B1 [0.03638334 0.05518651] toward target position G2 [200.  10.]
Start position B2 [163.21994   4.17591] toward target position G2 [200.  10.]
Figure 9: (left) L2 loss to target (normalized). (right) Training progress trajectories in transcriptional space.

👉 Interestingly, we can see that all optimization runs make progress (the training loss decreases), but they follow specific curves in the transcriptional space rather than going straight toward the targets. Those curves follow the valleys of the optimization landscape, which makes clear how the choice of the starting point conditions the success of the optimization, as a run might get stuck in a valley or in a local minimum.

For instance here, the default initial point B1 (as given in the SBML file) fails to reach either target (G1 or G2): its training loss plateaus (blue and green curves) and the optimization gets stuck, or at least makes very slow progress (as can be seen from the blue and green trajectories in transcriptional space). Another starting point B2, previously found by random search (and selected because it was the closest to G1 among all discoveries), manages to get very close to the target goal G1 (orange trajectory). However, it fails to reach the other target point G2 (red trajectory).
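The dependence on the starting point can be reproduced on a one-dimensional toy loss with two valleys (a made-up function, unrelated to the GRN): plain gradient descent started on the left settles in the deeper valley, while the same optimizer started on the right stays stuck in the shallower one.

```python
def f(x):
    """Hypothetical 1D loss with two valleys; the left one (near x = -1) is deeper."""
    return (x ** 2 - 1) ** 2 + 0.3 * x


def grad_f(x):
    """Analytical derivative of f."""
    return 4 * x * (x ** 2 - 1) + 0.3


def gradient_descent(x0, lr=0.01, n_steps=1000):
    x = x0
    for _ in range(n_steps):
        x -= lr * grad_f(x)
    return x


x_from_left = gradient_descent(-1.5)   # ends in the deeper, left valley
x_from_right = gradient_descent(1.5)   # stuck in the shallower, right valley
```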

👉 This shows the importance of having a good pool of initial discoveries for reaching desired targets with optimization. Because random search is not very efficient at covering the map of possible steady states, optimization is likely to fail for many possible targets G. Can we find a more efficient way to populate the pool of discoveries, given the same experimental budget?

Part 3: Curiosity-driven search as an efficient automated discovery tool

🤔 A bit of context

IMGEP pipeline and modules
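The modules below plug into a single loop. Here is a plain-Python sketch of that loop, following the module choices made in this part (hypercube goal sampling with scaling 1.3, nearest-neighbor selection, one evolutionary step with noise std 0.1 * (high - low)). The toy system, the random initialization phase, the centering of the scaled box and the clipping to bounds are assumptions for illustration; in AutoDiscJax each step is a separate module run in batch mode.

```python
import math
import random


def toy_system(params):
    """Hypothetical black-box system mapping 2 intervention values to a reached 2D point."""
    a, b = params
    return [a * a, math.sin(3.0 * b) + a]


def imgep(n_random=20, n_iterations=80, low=(0.0, 0.0), high=(2.0, 2.0), scaling=1.3, seed=0):
    """Toy IMGEP loop; returns the history of (intervention params, reached goal)."""
    rng = random.Random(seed)
    history = []

    # Seed the history with random interventions (assumed initialization)
    for _ in range(n_random):
        params = [rng.uniform(l, h) for l, h in zip(low, high)]
        history.append((params, toy_system(params)))

    for _ in range(n_iterations):
        reached = [goal for _, goal in history]

        # Goal generator: uniform sample in the bounding box of reached goals, scaled around its center
        target = []
        for dim in range(len(reached[0])):
            lo = min(g[dim] for g in reached)
            hi = max(g[dim] for g in reached)
            center, half_width = (lo + hi) / 2.0, scaling * (hi - lo) / 2.0
            target.append(rng.uniform(center - half_width, center + half_width))

        # Intervention selector: past intervention whose reached goal is nearest to the target
        source_params, _ = min(history, key=lambda item: math.dist(item[1], target))

        # Intervention optimizer: one Gaussian mutation, std 0.1 * (high - low), clipped to bounds
        params = [min(max(p + rng.gauss(0.0, 0.1 * (h - l)), l), h)
                  for p, l, h in zip(source_params, low, high)]

        # Rollout and goal embedding (the toy system directly returns the reached point)
        history.append((params, toy_system(params)))

    return history


history = imgep()
```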

Random Intervention Generator

# example: generate a random set of intervention parameters between low and high
key, subkey = jrandom.split(key)
intervention_params, log_data = random_intervention_generator(subkey)
print(jtu.tree_map(lambda node: node.shape, intervention_params))

# example in batch mode
key, *subkeys = jrandom.split(key, num=batch_size+ 1)
interventions_params, log_data =  batched_random_intervention_generator(jnp.array(subkeys))
print(jtu.tree_map(lambda node: node.shape, interventions_params))
{'y': {0: (1,), 1: (1,), 2: (1,), 3: (1,), 4: (1,), 5: (1,), 6: (1,), 7: (1,), 8: (1,), 9: (1,), 10: (1,)}}
{'y': {0: (100, 1), 1: (100, 1), 2: (100, 1), 3: (100, 1), 4: (100, 1), 5: (100, 1), 6: (100, 1), 7: (100, 1), 8: (100, 1), 9: (100, 1), 10: (100, 1)}}

Goal Embedding Encoder

goal_embedding_encoder_config = Dict()
goal_embedding_encoder_config.encoder_type = "filter"
goal_embedding_tree = "placeholder"
goal_embedding_encoder_config.out_treedef = jtu.tree_structure(goal_embedding_tree)
goal_embedding_encoder_config.out_shape = jtu.tree_map(lambda _: (len(observed_node_ids), ), goal_embedding_tree)
goal_embedding_encoder_config.out_dtype = jtu.tree_map(lambda _: jnp.float32, goal_embedding_tree)
goal_embedding_encoder_config.filter_fn = jtu.Partial(lambda system_outputs: system_outputs.ys[..., observed_node_ids, -1])

goal_embedding_encoder = create_goal_embedding_encoder_module(goal_embedding_encoder_config)
batched_goal_embedding_encoder = vmap(goal_embedding_encoder)
# example: encode the default system outputs 
key, subkey = jrandom.split(key)
reached_goal_embedding, log_data = goal_embedding_encoder(subkey, default_system_outputs)
print(reached_goal_embedding)

# example in batch mode: encode system outputs discovered by random search
key, *subkeys = jrandom.split(key, num=batch_size+ 1)
reached_goals_embeddings, log_data = batched_goal_embedding_encoder(jnp.array(subkeys), random_search_discoveries[100])
print(reached_goals_embeddings.shape)
[0.03638334 0.05518651]
(100, 2)
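The `filter` encoder reduces a whole rollout to the final values of the observed nodes (ERK and RKIPP_RP), which is why encoding the default outputs returns point B. A plain-Python equivalent of `system_outputs.ys[..., observed_node_ids, -1]` for a single rollout, on made-up trajectories (`final_observed_values` is a hypothetical helper):

```python
def final_observed_values(ys, observed_node_ids):
    """Goal embedding: last time point of each observed node.

    ys is indexed as ys[node][time], mirroring system_outputs.ys[node_ids, -1].
    """
    return [ys[node_id][-1] for node_id in observed_node_ids]


# Hypothetical trajectories of 3 nodes over 4 time points
ys = [
    [0.0, 0.01, 0.03, 0.036],
    [1.0, 0.8, 0.6, 0.5],
    [0.0, 0.02, 0.05, 0.055],
]
embedding = final_observed_values(ys, observed_node_ids=[0, 2])
```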

Goal-conditioned Achievement Loss

goal_achievement_loss_config = Dict()
goal_achievement_loss_config.loss_type = "L2" 

goal_achievement_loss = create_goal_achievement_loss_module(goal_achievement_loss_config)
batched_goal_achievement_loss = vmap(goal_achievement_loss)
# example
target_goal_embedding = target_point_1
key, subkey = jrandom.split(key)
gc_loss, log_data = goal_achievement_loss(subkey, reached_goal_embedding, target_goal_embedding)
print(gc_loss)

# example in batch mode
target_goals_embeddings = jnp.tile(target_goal_embedding[jnp.newaxis], (batch_size, 1))
key, *subkeys = jrandom.split(key, num= batch_size+ 1)
gc_losses, log_data = batched_goal_achievement_loss(jnp.array(subkeys), reached_goals_embeddings, target_goals_embeddings)
print(gc_losses.shape)
200.00253
(100,)
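Reproducing the printed value by hand is consistent with the "L2" achievement loss being the plain Euclidean distance between embeddings, without the (high - low) normalization used in evaluate_worker_fn in Part 2:

```python
import math

reached_goal_embedding = [0.03638334, 0.05518651]  # point B, as printed above
target_goal_embedding = [200.0, 4.0]               # target_point_1 (G1)

# Plain (unnormalized) Euclidean distance between the two embeddings
gc_loss = math.sqrt(sum((r - g) ** 2 for r, g in zip(reached_goal_embedding, target_goal_embedding)))
```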

Goal Generator

goal_generator_config = DictTree()
goal_generator_config.out_treedef = goal_embedding_encoder_config.out_treedef
goal_generator_config.out_shape = goal_embedding_encoder_config.out_shape
goal_generator_config.out_dtype = goal_embedding_encoder_config.out_dtype
goal_generator_config.low = 0.0
goal_generator_config.high = None
goal_generator_config.generator_type = "hypercube"
goal_generator_config.hypercube_scaling = 1.3

goal_generator = create_goal_generator_module(goal_generator_config)
batched_goal_generator = vmap(goal_generator, in_axes=(0, None, None))
# example 
key, subkey = jrandom.split(key)
next_target_goal, log_data = goal_generator(subkey, target_goals_embeddings, reached_goals_embeddings)
print(next_target_goal)

# example in batch mode
key, *subkeys = jrandom.split(key, num= batch_size+ 1)
next_target_goals_embeddings, log_data = batched_goal_generator(jnp.array(subkeys), target_goals_embeddings, reached_goals_embeddings)
print(next_target_goals_embeddings.shape)
[233.92404     1.6432697]
(100, 2)
Figure 10: IMGEP: uniform goal sampling in the (scaled) hyperrectangle of previously reached goals.
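A plain-Python sketch of hypercube goal sampling as described in the caption: compute the bounding box of reached goals, enlarge it by hypercube_scaling = 1.3, and sample uniformly inside. Scaling around the box center and clipping each coordinate at low = 0.0 are assumptions about details the tutorial does not spell out, and `hypercube_goal` is a hypothetical helper. The enlargement is what lets some goals fall beyond anything reached so far.

```python
import random


def hypercube_goal(rng, reached_goals, scaling=1.3, low=0.0):
    """Sample a goal uniformly in the bounding box of reached goals, enlarged by `scaling`."""
    goal = []
    for dim in range(len(reached_goals[0])):
        lo = min(g[dim] for g in reached_goals)
        hi = max(g[dim] for g in reached_goals)
        center, half_width = (lo + hi) / 2.0, scaling * (hi - lo) / 2.0
        goal.append(max(low, rng.uniform(center - half_width, center + half_width)))
    return goal


rng = random.Random(0)
reached = [[0.036, 0.055], [163.2, 4.18], [20.0, 1.0]]  # made-up reached (ERK, RKIPP_RP) points
goals = [hypercube_goal(rng, reached) for _ in range(1000)]
```

Clipping piles some probability mass exactly at 0; sampling directly inside the clipped box would avoid that, at the cost of a slightly different distribution.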

Goal-conditioned Intervention Selector

gc_intervention_selector_config = Dict()
gc_intervention_selector_config.selector_type="nearest_neighbor"
gc_intervention_selector_config.loss_f = goal_achievement_loss.loss_f
gc_intervention_selector_config.k = 1

gc_intervention_selector = create_gc_intervention_selector_module(gc_intervention_selector_config)
batched_gc_intervention_selector = vmap(gc_intervention_selector, in_axes=(0, 0, None))
# example
key, *subkeys = jrandom.split(key, num=batch_size + 1)
source_interventions_ids, log_data = batched_gc_intervention_selector(jnp.array(subkeys), next_target_goals_embeddings, reached_goals_embeddings)
print(source_interventions_ids.shape)
(100,)
Figure 11: IMGEP: goal-conditioned nearest neighbor intervention selection.
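Nearest-neighbor selection with `k = 1` boils down to an argmin over past outcomes. In this sketch, Euclidean distance stands in for the configured `loss_f`, and `select_nearest_neighbor` is a hypothetical helper rather than the autodiscjax API.

```python
import math
from typing import Sequence


def select_nearest_neighbor(target_goal: Sequence[float],
                            reached_goals: Sequence[Sequence[float]]) -> int:
    """Index of the past intervention whose reached goal is closest to target_goal (k = 1)."""
    return min(range(len(reached_goals)),
               key=lambda i: math.dist(reached_goals[i], target_goal))


reached_goals = [[150.0, 5.0], [20.0, 40.0], [90.0, 12.0]]
print(select_nearest_neighbor([100.0, 10.0], reached_goals))  # 2
```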

Goal-conditioned Intervention Optimizer

gc_intervention_optimizer_config = Dict()
gc_intervention_optimizer_config.out_treedef = random_intervention_generator.out_treedef
gc_intervention_optimizer_config.out_shape = random_intervention_generator.out_shape
gc_intervention_optimizer_config.out_dtype = random_intervention_generator.out_dtype
gc_intervention_optimizer_config.low = random_intervention_generator.low
gc_intervention_optimizer_config.high = random_intervention_generator.high
gc_intervention_optimizer_config.optimizer_type = "EA"
gc_intervention_optimizer_config.n_optim_steps = 1
gc_intervention_optimizer_config.n_workers = 1
gc_intervention_optimizer_config.init_noise_std = jtu.tree_map(lambda low, high: 0.1 * (high - low), 
                                                               gc_intervention_optimizer_config.low, gc_intervention_optimizer_config.high)

gc_intervention_optimizer = create_gc_intervention_optimizer_module(gc_intervention_optimizer_config)
null_perturbation_generator, null_perturbation_fn = create_perturbation_module(Dict(perturbation_type="null"))
null_rollout_statistics_encoder = create_rollout_statistics_encoder_module(Dict(statistics_type="null"))
partial_gc_intervention_optimizer = jtu.Partial(gc_intervention_optimizer,
                                        perturbation_generator=null_perturbation_generator, perturbation_fn=null_perturbation_fn,
                                        intervention_fn=intervention_fn, system_rollout=system_rollout,
                                        goal_embedding_encoder=goal_embedding_encoder, goal_achievement_loss=goal_achievement_loss,
                                        rollout_statistics_encoder=null_rollout_statistics_encoder
                                        )
batched_gc_intervention_optimizer = vmap(partial_gc_intervention_optimizer, in_axes=(0, 0, 0, None))
# example: optimize the initial species amounts (starting from the default ones) to reach an ERK-RKIPP_RP steady state B (150,5)
key, subkey = jrandom.split(key)
optimized_intervention_params, log_data = partial_gc_intervention_optimizer(subkey, start_intervention_params_1, target_point_1, reached_goals_embeddings)
print(jtu.tree_map(lambda node: node.shape, optimized_intervention_params))

# example in batch mode: optimize the selected interventions (closest to target) toward their respective targets
previous_interventions_params = DictTree()
for node_idx in random_intervention_generator_config.controlled_node_ids:
    previous_interventions_params.y[node_idx] = random_search_discoveries[100].ys[:, node_idx, 0]

start_interventions_params = jtu.tree_map(lambda x: x[source_interventions_ids], previous_interventions_params)

key, *subkeys = jrandom.split(key, num=batch_size + 1)
optimized_interventions_params, log_data = batched_gc_intervention_optimizer(jnp.array(subkeys), start_interventions_params, next_target_goals_embeddings, reached_goals_embeddings)
print(jtu.tree_map(lambda node: node.shape, optimized_interventions_params))
{'y': {0: (1,), 1: (1,), 2: (1,), 3: (1,), 4: (1,), 5: (1,), 6: (1,), 7: (1,), 8: (1,), 9: (1,), 10: (1,)}}
{'y': {0: (100, 1), 1: (100, 1), 2: (100, 1), 3: (100, 1), 4: (100, 1), 5: (100, 1), 6: (100, 1), 7: (100, 1), 8: (100, 1), 9: (100, 1), 10: (100, 1)}}
Figure 12: IMGEP: goal-conditioned intervention optimization (here local diffusion).

👉 The optimized interventions make progress toward G3 and G7. There is no progress toward G9, which was already quite close. Even though the new point Z9 is further from G9, it falls in an uncovered area and will be useful for future goals.

Figure 13: IMGEP: discoveries after one iteration.

👉 This simple example already shows why IMGEP is much more efficient than random search at finding diverse reachable final states.
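The optimizer step can be understood through a generic elitist evolutionary local search. Each of `n_workers` candidates is a Gaussian mutation of the current best parameters, with std = `noise_scale * (high - low)` mirroring `init_noise_std`. Candidates are clipped to the bounds and scored by rolling out the system. This section does not say whether autodiscjax's "EA" optimizer keeps the elite or returns the last mutant, so the elitist choice here is an assumption. `ea_local_search` and the toy rollout are hypothetical.

```python
import random
from typing import Callable, List, Sequence


def ea_local_search(rng: random.Random,
                    start_params: Sequence[float],
                    target_goal: Sequence[float],
                    rollout_and_encode: Callable[[List[float]], List[float]],
                    low: Sequence[float],
                    high: Sequence[float],
                    n_optim_steps: int = 1,
                    n_workers: int = 1,
                    noise_scale: float = 0.1) -> List[float]:
    """Elitist Gaussian-mutation search; the returned params are never worse than the start."""

    def loss(params: Sequence[float]) -> float:
        # Euclidean goal-achievement loss between the reached and target goal embeddings
        reached = rollout_and_encode(list(params))
        return sum((r - t) ** 2 for r, t in zip(reached, target_goal)) ** 0.5

    # mutation std proportional to the allowed range, like init_noise_std = 0.1 * (high - low)
    stds = [noise_scale * (hi - lo) for lo, hi in zip(low, high)]
    best_params, best_loss = list(start_params), loss(start_params)
    for _ in range(n_optim_steps):
        # each worker independently mutates the current best (clipped to [low, high])
        candidates = [[min(max(p + rng.gauss(0.0, s), lo), hi)
                       for p, s, lo, hi in zip(best_params, stds, low, high)]
                      for _ in range(n_workers)]
        for candidate in candidates:
            candidate_loss = loss(candidate)
            if candidate_loss < best_loss:
                best_params, best_loss = candidate, candidate_loss
    return best_params


def toy_rollout(params: List[float]) -> List[float]:
    # hypothetical stand-in for system_rollout followed by goal_embedding_encoder
    return [2.0 * params[0], params[1] + 1.0]


best = ea_local_search(random.Random(0), [1.0, 1.0], [10.0, 5.0], toy_rollout,
                       low=[0.0, 0.0], high=[10.0, 10.0], n_optim_steps=50, n_workers=4)
print(best)
```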

Run experiment pipeline

   

# Run IMGEP
jax_platform_name = "cpu"
seed = 0
n_random_batches = 2 
n_imgep_batches = 8
batch_size = 20
imgep_experiment_data_save_folder = "data/imgep_data"
if not os.path.exists(os.path.join(imgep_experiment_data_save_folder, "history.pickle")):
    run_imgep_experiment(jax_platform_name, seed, n_random_batches, n_imgep_batches, batch_size,
                         imgep_experiment_data_save_folder,
                         random_intervention_generator, intervention_fn,
                         null_perturbation_generator, null_perturbation_fn,
                         system_rollout, null_rollout_statistics_encoder,
                         goal_generator, gc_intervention_selector, gc_intervention_optimizer,
                         goal_embedding_encoder, goal_achievement_loss,
                         out_sanity_check=False, save_modules=False, save_logs=False)

# Run Random Search
rs_experiment_data_save_folder = "data/rs_data"
if not os.path.exists(os.path.join(rs_experiment_data_save_folder, "history.pickle")):
    run_rs_experiment(jax_platform_name, seed, n_random_batches+n_imgep_batches, batch_size, 
                      rs_experiment_data_save_folder,
                      random_intervention_generator, intervention_fn,
                      null_perturbation_generator, null_perturbation_fn,
                      system_rollout, null_rollout_statistics_encoder,
                      out_sanity_check=False, save_modules=False, save_logs=False)
imgep_experiment_history = DictTree.load(os.path.join(imgep_experiment_data_save_folder, "history.pickle"))
imgep_reached_goals_embeddings = imgep_experiment_history.reached_goal_embedding_library
print(imgep_reached_goals_embeddings.shape)
rs_experiment_history = DictTree.load(os.path.join(rs_experiment_data_save_folder, "history.pickle"))
rs_reached_goals_embeddings = rs_experiment_history.system_output_library.ys[:, jnp.array(observed_node_ids), -1]
print(rs_reached_goals_embeddings.shape)
(200, 2)
(200, 2)
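We expect `run_imgep_experiment` to follow roughly this structure. Random bootstrap batches come first. In each goal-directed batch that follows, goals are sampled, the nearest past intervention is selected and mutated, and the whole batch is then rolled out. The pure-Python sketch below reproduces that structure on a toy system. It is a simplified stand-in, not the library's implementation: `run_toy_imgep`, the cubic toy system and the single unscored mutation per goal are all assumptions.

```python
import random
from typing import Callable, List, Tuple

Vector = List[float]


def run_toy_imgep(system: Callable[[Vector], Vector], low: Vector, high: Vector,
                  n_random_batches: int, n_imgep_batches: int, batch_size: int,
                  hypercube_scaling: float = 1.3, noise_scale: float = 0.1,
                  seed: int = 0) -> Tuple[List[Vector], List[Vector]]:
    """Toy IMGEP loop returning the intervention library and the reached-goal library."""
    rng = random.Random(seed)
    params_library: List[Vector] = []
    reached_library: List[Vector] = []

    # Phase 1: random interventions bootstrap the libraries
    for _ in range(n_random_batches * batch_size):
        params = [rng.uniform(lo, hi) for lo, hi in zip(low, high)]
        params_library.append(params)
        reached_library.append(system(params))

    # Phase 2: goal-directed batches (every goal in a batch sees the libraries as before the batch)
    n_dims = len(reached_library[0])
    for _ in range(n_imgep_batches):
        batch_params = []
        for _ in range(batch_size):
            # (a) sample a target goal in the scaled bounding box of the reached goals
            goal = []
            for d in range(n_dims):
                vals = [z[d] for z in reached_library]
                center, half = (min(vals) + max(vals)) / 2, (max(vals) - min(vals)) / 2
                goal.append(rng.uniform(center - hypercube_scaling * half, center + hypercube_scaling * half))
            # (b) select the past intervention whose reached goal is closest to the target
            idx = min(range(len(reached_library)),
                      key=lambda i: sum((z - g) ** 2 for z, g in zip(reached_library[i], goal)))
            # (c) mutate it locally (Gaussian noise, clipped to the bounds)
            batch_params.append([min(max(p + rng.gauss(0.0, noise_scale * (hi - lo)), lo), hi)
                                 for p, lo, hi in zip(params_library[idx], low, high)])
        # (d) roll out the whole batch, then append it to the history
        for params in batch_params:
            params_library.append(params)
            reached_library.append(system(params))
    return params_library, reached_library


def cubic_system(params: Vector) -> Vector:
    # hypothetical toy "network": the final state is the cube of each input
    return [p ** 3 for p in params]


params_lib, reached_lib = run_toy_imgep(cubic_system, [0.0, 0.0], [1.0, 1.0],
                                        n_random_batches=2, n_imgep_batches=8, batch_size=20)
print(len(reached_lib))  # 200
```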

👉 Sample efficiency

Intrinsically motivated goal exploration algorithms are designed to autonomously discover the widest possible range of diverse effects that can be produced in an initially unknown system (here, our biological network). A first way to evaluate these exploration algorithms is therefore to measure how well and how fast they cover the behavior space, in particular compared to random search.

epsilon = 0.033

analytic_bc_space_low = jnp.minimum(jnp.nanmin(imgep_reached_goals_embeddings, 0), jnp.nanmin(rs_reached_goals_embeddings, 0))
analytic_bc_space_high = jnp.maximum(jnp.nanmax(imgep_reached_goals_embeddings, 0), jnp.nanmax(rs_reached_goals_embeddings, 0))

def calc_analytic_bc_coverage(reached_endpoints, epsilon):
    """Grow the union of epsilon-radius disks around the reached endpoints, one endpoint at a time.

    Returns the final union polygon and the covered area after each step (the diversity curve).
    """
    for step_idx, reached_endpoint in enumerate(reached_endpoints):
        if step_idx == 0:
            union_polygon = Point(reached_endpoint).buffer(epsilon)
            covered_areas = [union_polygon.area]
        else:
            union_polygon = unary_union([union_polygon, Point(reached_endpoint).buffer(epsilon)])
            covered_areas.append(union_polygon.area)

    return union_polygon, covered_areas

imgep_reached_endpoints = (imgep_reached_goals_embeddings-analytic_bc_space_low) / (analytic_bc_space_high-analytic_bc_space_low)
imgep_union_polygon, imgep_covered_areas = calc_analytic_bc_coverage(imgep_reached_endpoints, epsilon=epsilon)
rs_reached_endpoints = (rs_reached_goals_embeddings-analytic_bc_space_low) / (analytic_bc_space_high-analytic_bc_space_low)
rs_union_polygon, rs_covered_areas = calc_analytic_bc_coverage(rs_reached_endpoints, epsilon=epsilon)
Figure 14: Diversity of behaviors discovered by the different algorithm variants. (left) All discovered behaviors (stable endpoints). (middle) Discovered reachable space (union of epsilon-radius balls centered on the discovered endpoints) for random search (pink) and IMGEP (blue). Results are shown for epsilon=0.033. (right) Diversity of behaviors discovered throughout exploration, using the area of the discovered reachable space as the diversity measure.
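The shapely-based coverage measure can be cross-checked without dependencies by approximating the area of the union of epsilon-disks on a regular grid. The covered cells are updated as each endpoint is added. This sketch assumes endpoints normalized to [0, 1]², as `rs_reached_endpoints` and `imgep_reached_endpoints` are. `coverage_curve` is a hypothetical helper. Its result converges to the exact union area as `resolution` grows, so its last value should closely match the last entry of the corresponding `*_covered_areas` list.

```python
import math
from typing import List, Sequence, Set, Tuple


def coverage_curve(points: Sequence[Tuple[float, float]], epsilon: float,
                   resolution: int = 500) -> List[float]:
    """Approximate area of the union of epsilon-disks after each point is added.

    Assumes points normalized to [0, 1]^2, so the grid spans [-epsilon, 1 + epsilon]^2.
    A grid cell counts as covered when its center lies inside at least one disk.
    NaN endpoints are skipped (the area is carried over unchanged).
    """
    lo = -epsilon
    cell = (1.0 + 2 * epsilon) / resolution
    eps2 = epsilon ** 2
    covered: Set[Tuple[int, int]] = set()
    areas: List[float] = []
    for px, py in points:
        if not (math.isnan(px) or math.isnan(py)):
            # only visit the grid cells inside this disk's bounding box
            i_min = max(0, int((px - epsilon - lo) / cell))
            i_max = min(resolution - 1, int((px + epsilon - lo) / cell))
            j_min = max(0, int((py - epsilon - lo) / cell))
            j_max = min(resolution - 1, int((py + epsilon - lo) / cell))
            for i in range(i_min, i_max + 1):
                cx = lo + (i + 0.5) * cell
                for j in range(j_min, j_max + 1):
                    cy = lo + (j + 0.5) * cell
                    if (cx - px) ** 2 + (cy - py) ** 2 <= eps2:
                        covered.add((i, j))
        areas.append(len(covered) * cell * cell)
    return areas


# Usage: two overlapping disks cover less than twice the area of one
areas = coverage_curve([(0.50, 0.50), (0.52, 0.50)], epsilon=0.033)
print(areas[0], areas[1])
```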

👉 Finding Salient Stimuli

# calc clusters in behavior space
clusterer = hdbscan.HDBSCAN(min_cluster_size=10)
cluster_labels = clusterer.fit_predict(imgep_reached_endpoints)

# project sampled params in 2D space with TSNE
imgep_sampled_params = jnp.array(list(imgep_experiment_history.intervention_params_library.y.values())).squeeze().transpose()
rs_sampled_params = jnp.array(list(rs_experiment_history.intervention_params_library.y.values())).squeeze().transpose()
all_sampled_params = jnp.concatenate([imgep_sampled_params, rs_sampled_params])

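# the first rows of all_sampled_params come from IMGEP, the remaining ones from random search
n_imgep_samples = imgep_sampled_params.shape[0]
tsne = TSNE(n_components=2)
all_sampled_params = tsne.fit_transform(all_sampled_params)
imgep_sampled_params, rs_sampled_params = all_sampled_params[:n_imgep_samples], all_sampled_params[n_imgep_samples:]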
#@title [Figure 15]
fig_idx = 15

if nb_mode == "run":

    fig = make_subplots(rows=1, cols=2, horizontal_spacing=0.05)
    fig.update_layout(**default_layout)
    fig.update_annotations(**default_annotation_layout)

    # Add interventions sampled by random search
    fig.add_trace(go.Scatter(x=rs_sampled_params[:,0], y=rs_sampled_params[:,1], showlegend=False, 
                             name="random search", legendgroup="random search", 
                             mode="markers", marker=dict(color=default_colors[0], size=4)), 
                  row=1, col=2)
    eps = 1.5
    poly = unary_union([Point(point).buffer(eps) for point in rs_sampled_params])
    poly = poly.buffer(eps*5, join_style=1).buffer(-eps*5, join_style=1)
    # collect the exterior contour(s) of the shape, without shadowing the loop variables below
    xs, ys = [], []
    if poly.geom_type == 'MultiPolygon':
        for geom in poly.geoms:
            geom_x, geom_y = geom.exterior.coords.xy
            xs.append(np.array(geom_x))
            ys.append(np.array(geom_y))
    elif poly.geom_type == 'Polygon':
        geom_x, geom_y = poly.exterior.coords.xy
        xs.append(np.array(geom_x))
        ys.append(np.array(geom_y))
    for i, (x, y) in enumerate(zip(xs, ys)):
        fig.add_trace(go.Scatter(x=x, y=y, fill="toself",
                                 name="random search", legendgroup="random search", showlegend=(i==0),
                                 line=dict(color=default_colors[0]), hoverinfo="skip"), row=1, col=2)

    # Add discoveries made by random search
    fig.add_trace(go.Scatter(x=rs_reached_endpoints[:,0], y=rs_reached_endpoints[:,1], showlegend=False, 
                             name="random search", legendgroup="random search", 
                             mode="markers", marker=dict(color=default_colors[0], size=4)), 
                  row=1, col=1)

    # Add points / shape contour per IMGEP cluster
    for label_idx in [-1,0,3,2,1]:

        cluster_point_ids = jnp.where(cluster_labels==label_idx)[0]
        z_points = imgep_reached_endpoints[cluster_point_ids]
        i_points = imgep_sampled_params[cluster_point_ids]

        if label_idx < 0:
            show_shape = False
            marker_size = 4
            color_idx = 7
            name = f"N/A"
        else:
            show_shape = True
            marker_size = 4
            color_idx = [4,3,2,6][label_idx]
            name = f"cluster {label_idx+1}"

        # shape contour
        if show_shape:
            for col_idx, (points, eps) in enumerate(zip([z_points, i_points], [0.05, 1.5])):
                poly = unary_union([Point(point).buffer(eps) for point in points])
                poly = poly.buffer(eps*5, join_style=1).buffer(-eps*5, join_style=1)
                # collect the exterior contour(s) of the shape, without shadowing the loop variables below
                xs, ys = [], []
                if poly.geom_type == 'MultiPolygon':
                    for geom in poly.geoms:
                        geom_x, geom_y = geom.exterior.coords.xy
                        xs.append(np.array(geom_x))
                        ys.append(np.array(geom_y))
                elif poly.geom_type == 'Polygon':
                    geom_x, geom_y = poly.exterior.coords.xy
                    xs.append(np.array(geom_x))
                    ys.append(np.array(geom_y))
                for i, (x, y) in enumerate(zip(xs, ys)):
                    fig.add_trace(go.Scatter(x=x, y=y, fill="toself", fillcolor=default_colors_shade[color_idx],
                                             name=name, legendgroup=name, showlegend=(i==0)&(col_idx==0),
                                             line=dict(color=default_colors[color_idx]), hoverinfo="skip"), row=1, col=col_idx+1)

        # points
        fig.add_trace(go.Scatter(x=z_points[:,0], y=z_points[:,1],
                                 name=name, legendgroup=name, showlegend=(label_idx==-1),
                                 mode="markers", marker_color=default_colors[color_idx], marker_size=marker_size), 
                      row=1, col=1)
        fig.add_trace(go.Scatter(x=i_points[:,0], y=i_points[:,1], 
                                 name=name, legendgroup=name, showlegend=False,
                                 mode="markers", marker_color=default_colors[color_idx], marker_size=marker_size), 
                      row=1, col=2)


    # Add background shape
    fig.add_vrect(x0=-0.1, x1=1.05, 
                  fillcolor="#BAE1FF", opacity=.8,
                  line=dict(color="#6C8EBF", width=2), 
                  annotation_text="Behavior Space <i>Z</i>", annotation_position="top", annotation_font=dict(color='black', size=14),
                  layer="below", row=1, col=1)
    fig.add_vrect(x0=-40, x1=30, 
                  fillcolor="#F6E785", opacity=.8,
                  line=dict(color="#C79714", width=2), 
                  annotation_text="Intervention Space <i>I</i>", annotation_position="top", annotation_font=dict(color='black', size=14),
                  layer="below", row=1, col=2)


    # Update Layout 
    fig.update_xaxes(visible=False)
    fig.update_yaxes(visible=False)


    # Serialize fig to json and save
    if nb_save_outputs:
        fig.write_json(f"figures/tuto1_fig_{fig_idx}.json")


elif nb_mode == "load":
    json_fig = requests.get(f"https://raw.githubusercontent.com/flowersteam/curious-exploration-of-grn-competencies/main/notebooks/figures/tuto1_fig_{fig_idx}.json").text
    fig = plotly.io.from_json(json_fig)

# Display Fig
width, height = 940, 400
t = f"Mapping between Intervention Space and Behavior space. " 

if nb_renderer == "html":
    html_fig = make_html_fig(fig_idx, fig, width, height, t)
    display(HTML(html_fig))

elif nb_renderer == "img":
    img_fig, img_title = make_img_fig(fig_idx, fig, width, height, t)
    display(Image(img_fig))
    display(Markdown(img_title))
Figure 15: Mapping between Intervention Space and Behavior space.

👉 Adaptive Design Choices

Finally, we show that IMGEP facilitates the engineer's task, e.g. when we are unsure of the appropriate range for the parameter and/or goal space.

Adaptive Goal Space

Figure 16: IMGEP: adaptive goal space extent.

Adaptive Parameter Space Variant

r = 1

# Random generator initially constrained to the tight range (r=1) of what we know is feasible from the default rollout,
# but this time only for the first 10% of runs; after that, the algorithm adapts its parameter space extent
n_random_batches = 1
n_imgep_batches = 9
constrained_random_intervention_generator_config = deepcopy(random_intervention_generator_config)
for node_idx in random_intervention_generator_config.controlled_node_ids:
    constrained_random_intervention_generator_config.low.y[node_idx] = default_system_outputs.ys[node_idx].min() / r
    constrained_random_intervention_generator_config.high.y[node_idx] = default_system_outputs.ys[node_idx].max() * r   
constrained_random_intervention_generator, intervention_fn = create_intervention_module(constrained_random_intervention_generator_config)

# Goal-conditioned optimizer: keep only the non-negativity constraint (low=0, no upper bound)
# and set the local mutation amplitude based on what we know is feasible (r=1) from the default rollout
adaptive_gc_intervention_optimizer_config = deepcopy(gc_intervention_optimizer_config)
adaptive_gc_intervention_optimizer_config.low = jtu.tree_map(lambda node: jnp.zeros_like(node), gc_intervention_optimizer_config.low)
adaptive_gc_intervention_optimizer_config.high = None
adaptive_gc_intervention_optimizer_config.init_noise_std = jtu.tree_map(lambda low, high: r*(high-low), 
                                                                        constrained_random_intervention_generator.low,
                                                                        constrained_random_intervention_generator.high)
adaptive_gc_intervention_optimizer = create_gc_intervention_optimizer_module(adaptive_gc_intervention_optimizer_config)

# Run Adaptive IMGEP variant
adaptive_imgep_experiment_data_save_folder = "data/adaptive_imgep_data"
if not os.path.exists(os.path.join(adaptive_imgep_experiment_data_save_folder, "history.pickle")):
    run_imgep_experiment(jax_platform_name, seed, n_random_batches, n_imgep_batches, batch_size,
                         adaptive_imgep_experiment_data_save_folder,
                         constrained_random_intervention_generator, intervention_fn,
                         null_perturbation_generator, null_perturbation_fn,
                         system_rollout, null_rollout_statistics_encoder,
                         goal_generator, gc_intervention_selector, adaptive_gc_intervention_optimizer,
                         goal_embedding_encoder, goal_achievement_loss,
                         out_sanity_check=False, save_modules=False, save_logs=False)


# Run Adaptive RandomMut variant (IMGEP with null goal generation and random intervention selection)
null_goal_generator_config = deepcopy(goal_generator_config)
null_goal_generator_config.hypercube_scaling = 0.0
null_goal_generator = create_goal_generator_module(null_goal_generator_config)

random_intervention_selector_config = deepcopy(gc_intervention_selector_config)
random_intervention_selector_config.selector_type = "random"
random_intervention_selector = create_gc_intervention_selector_module(random_intervention_selector_config)

adaptive_rmut_experiment_data_save_folder = "data/adaptive_rmut_data"
if not os.path.exists(os.path.join(adaptive_rmut_experiment_data_save_folder, "history.pickle")):
    run_imgep_experiment(jax_platform_name, seed, n_random_batches, n_imgep_batches, batch_size,
                         adaptive_rmut_experiment_data_save_folder,
                         constrained_random_intervention_generator, intervention_fn,
                         null_perturbation_generator, null_perturbation_fn,
                         system_rollout, null_rollout_statistics_encoder,
                         null_goal_generator, random_intervention_selector, adaptive_gc_intervention_optimizer,
                         goal_embedding_encoder, goal_achievement_loss,
                         out_sanity_check=False, save_modules=False, save_logs=False)
adaptive_imgep_experiment_history = DictTree.load(os.path.join(adaptive_imgep_experiment_data_save_folder, "history.pickle"))
adaptive_imgep_reached_goals_embeddings = adaptive_imgep_experiment_history.reached_goal_embedding_library
adaptive_rmut_experiment_history = DictTree.load(os.path.join(adaptive_rmut_experiment_data_save_folder, "history.pickle"))
adaptive_rmut_reached_goals_embeddings = adaptive_rmut_experiment_history.system_output_library.ys[:, jnp.array(observed_node_ids), -1]

adaptive_analytic_bc_space_low = jnp.minimum(jnp.nanmin(adaptive_imgep_reached_goals_embeddings, 0), jnp.nanmin(adaptive_rmut_reached_goals_embeddings, 0))
adaptive_analytic_bc_space_high = jnp.maximum(jnp.nanmax(adaptive_imgep_reached_goals_embeddings, 0), jnp.nanmax(adaptive_rmut_reached_goals_embeddings, 0))

adaptive_imgep_reached_endpoints = (adaptive_imgep_reached_goals_embeddings-adaptive_analytic_bc_space_low) / (adaptive_analytic_bc_space_high-adaptive_analytic_bc_space_low)
adaptive_imgep_union_polygon, adaptive_imgep_covered_areas = calc_analytic_bc_coverage(adaptive_imgep_reached_endpoints, epsilon=epsilon)
adaptive_rmut_reached_endpoints = (adaptive_rmut_reached_goals_embeddings-adaptive_analytic_bc_space_low) / (adaptive_analytic_bc_space_high-adaptive_analytic_bc_space_low)
adaptive_rmut_union_polygon, adaptive_rmut_covered_areas = calc_analytic_bc_coverage(adaptive_rmut_reached_endpoints, epsilon=epsilon)
Figure 17: Diversity of behaviors discovered by the adaptive parameter space algorithm variants. (left) All discovered behaviors (stable endpoints). (middle) Discovered reachable space (union of epsilon-radius balls centered on the discovered endpoints) for the adaptive RandomMut variant (pink) and the adaptive IMGEP variant (blue). Results are shown for epsilon=0.033. (right) Diversity of behaviors discovered throughout exploration, using the area of the discovered reachable space as the diversity measure.
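The mechanism behind the adaptive variants can be illustrated with a toy one-dimensional example. Suppose mutations are clamped only from below, to keep amounts non-negative as with `low = 0` and `high = None` above. Mutating library members then lets values escape the initial sampling range, which clamping to the initial `[low, high]` cannot do. The helpers `mutate` and `explored_max` are hypothetical. They pick parents at random, in the spirit of the RandomMut variant.

```python
import random
from typing import List, Optional, Sequence


def mutate(rng: random.Random, params: Sequence[float], stds: Sequence[float],
           high: Optional[Sequence[float]]) -> List[float]:
    """Gaussian mutation clamped below at 0 (non-negative amounts), and above only if `high` is given."""
    child = [max(p + rng.gauss(0.0, s), 0.0) for p, s in zip(params, stds)]
    if high is not None:
        child = [min(x, h) for x, h in zip(child, high)]
    return child


def explored_max(rng: random.Random, init_low: Sequence[float], init_high: Sequence[float],
                 high: Optional[Sequence[float]], n_init: int, n_mutations: int,
                 scale: float = 1.0) -> List[float]:
    """Random phase inside [init_low, init_high], then mutations of randomly chosen library members."""
    library = [[rng.uniform(lo, hi) for lo, hi in zip(init_low, init_high)] for _ in range(n_init)]
    # mutation amplitude set from the initial feasible range, like init_noise_std = r * (high - low)
    stds = [scale * (hi - lo) for lo, hi in zip(init_low, init_high)]
    for _ in range(n_mutations):
        library.append(mutate(rng, rng.choice(library), stds, high))
    return [max(p[d] for p in library) for d in range(len(init_low))]


fixed_max = explored_max(random.Random(0), [0.0], [1.0], high=[1.0], n_init=10, n_mutations=200)
adaptive_max = explored_max(random.Random(0), [0.0], [1.0], high=None, n_init=10, n_mutations=200)
print(fixed_max, adaptive_max)  # the fixed-bounds maximum never exceeds 1.0
```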

Part 4: Empirical tests for analyzing navigation competencies

🤔 A bit of context

# Evaluate only the last 40 trajectories (2 * batch_size) to go faster
eval_system_outputs_library = jtu.tree_map(lambda node: node[-2*batch_size:], imgep_experiment_history.system_output_library)
eval_intervention_params_library = jtu.tree_map(lambda node: node[-2*batch_size:], imgep_experiment_history.intervention_params_library)

Trajectory characteristics

def calc_settling_time(dist_vals, settling_time_threshold):
    # assumes normalized dist_vals starting at 1 and ending at 0
    # returns the last index where the distance is still >= threshold (-1 if it is always below)
    settling_time = jnp.where(~(dist_vals < settling_time_threshold), size=len(dist_vals), fill_value=-1)[0].max()
    return settling_time


def calc_travelling_time(trajectories):
    distance_travelled = jnp.cumsum(jnp.sqrt(jnp.sum(jnp.diff(trajectories, axis=-1)** 2, axis=-2)), axis=-1)
    distance_travelled = distance_travelled / distance_travelled.max(-1) # normalize between 0 and 1
    T10 = jnp.where(distance_travelled >= 0.1, size=distance_travelled.shape[-1], fill_value=-1)[0][0]
    T90 = jnp.where(distance_travelled >= 0.9, size=distance_travelled.shape[-1], fill_value=-1)[0][0]
    return T10, T90


def calc_trajectories_statistics(trajectories, deltaT, settling_time_threshold):
    trajectories = trajectories[..., 1:]  # remove the first step (we don't want to take into account big jumps happening in the first step)

    # settling time:  first time T such that the distance between y(t) and yfinal ≤ 0.02 × |yfinal – yinit| for t ≥ T
    # normalize such that origin is final point and unit=(end-origin)
    extent = (trajectories.max(-1) - trajectories.min(-1))
    extent = extent.at[extent == 0.].set(1.)
    normalized_trajectories = trajectories / extent[..., jnp.newaxis]
    distance_to_target = jnp.linalg.norm(normalized_trajectories - normalized_trajectories[..., -1][..., jnp.newaxis], axis=1)
    distance_to_target = distance_to_target / distance_to_target[:, 0][:, jnp.newaxis]
    settling_times = vmap(calc_settling_time, in_axes=(0, None))(distance_to_target, settling_time_threshold)

    # travelling time: time it takes for the response to travel from 10% to 90% of the way from yinit to yfinal
    T10s, T90s = vmap(calc_travelling_time)(normalized_trajectories)

    # detours (duration and area)
    detours_duration = []
    detours_area = []
    detours_timesteps = []

    for sample_idx in range(len(distance_to_target)):
        detour_timesteps = []
        detour_duration = 0.
        detour_area = 0.

        if settling_times[sample_idx] > 0:
            cur_distance_to_target = distance_to_target[sample_idx, :settling_times[sample_idx]]
            is_distance_increasing = jnp.concatenate(
                [jnp.array([False]), jnp.diff(cur_distance_to_target) > 0])
            is_distance_decreasing = jnp.concatenate(
                [jnp.array([True]), jnp.diff(cur_distance_to_target) < 0])
            start_detour_timesteps = jnp.where(is_distance_decreasing[:-1] & is_distance_increasing[1:])[0]
            if len(start_detour_timesteps) > 0:
                start_detour_dist_vals = cur_distance_to_target[start_detour_timesteps]
                end_detour_timesteps = []

                for start_detour_timestep, start_detour_dist_val in zip(start_detour_timesteps, start_detour_dist_vals):
                    possible_detour_timesteps = jnp.where((cur_distance_to_target[:-1] >= start_detour_dist_val) &
                                                          (cur_distance_to_target[1:] <= start_detour_dist_val))[0] + 1
                    # take the first time step (after start_detour_timestep) where distance curve is crossing back y=start_detour_dist_val
                    # if no crossing back before settling time, we consider settling time as the end of the detour
                    possible_end_detour_timesteps = possible_detour_timesteps[possible_detour_timesteps > start_detour_timestep]
                    if len(possible_end_detour_timesteps) > 0:
                        end_detour_timestep = possible_end_detour_timesteps[0]
                    else:
                        end_detour_timestep = settling_times[sample_idx]-1
                    end_detour_timesteps.append(end_detour_timestep)

                # calc union of intervals (in case some overlaps due to noise)
                detour_timesteps = jnp.where(jnp.array([(jnp.arange(len(cur_distance_to_target)) >= start) &
                                                        (jnp.arange(len(cur_distance_to_target)) <= end)
                                                        for (start, end) in
                                                        zip(start_detour_timesteps, end_detour_timesteps)]).any(0))[0]
                detour_duration = len(detour_timesteps)

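                rel_start_detours_timesteps = jnp.concatenate([jnp.array([0]), jnp.where((detour_timesteps[1:] - detour_timesteps[:-1]) > 1)[0]+1])
                # exclusive end of each contiguous run, so that detour_timesteps[start:end] covers the whole run
                rel_end_detours_timesteps = jnp.concatenate([rel_start_detours_timesteps[1:], jnp.array([len(detour_timesteps)])])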

                detour_polygon = Polygon()
                valid_detour_timesteps = jnp.empty((0,))
                for start, end in zip(rel_start_detours_timesteps, rel_end_detours_timesteps):
                    detour_points = normalized_trajectories[sample_idx][:, detour_timesteps[start:end]].transpose()
                    if len(detour_points) >= 3:
                        cur_detour_polygon = Polygon([*detour_points])
                        if cur_detour_polygon.is_valid:
                            detour_polygon = unary_union([detour_polygon, cur_detour_polygon])
                            valid_detour_timesteps = jnp.concatenate([valid_detour_timesteps, detour_timesteps[start:end]])

                detour_area = detour_polygon.area

        detours_timesteps.append(detour_timesteps)
        detours_duration.append(detour_duration)
        detours_area.append(detour_area)   


    trajectories_statistics = DictTree()
    trajectories_statistics.distance_to_target = distance_to_target
    trajectories_statistics.settling_times = settling_times
    trajectories_statistics.T10s = T10s
    trajectories_statistics.T90s = T90s
    trajectories_statistics.detours_timesteps = detours_timesteps
    trajectories_statistics.detours_duration = jnp.array(detours_duration) 
    trajectories_statistics.detours_area = jnp.array(detours_area)

    return trajectories_statistics
deltaT = system_rollout_config.deltaT
settling_time_threshold=0.02
trajectories = eval_system_outputs_library.ys[:, jnp.array(observed_node_ids), :]
trajectories_statistics = calc_trajectories_statistics(trajectories, deltaT, settling_time_threshold)
Figure 18: Trajectories Characteristics.
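The settling-time and travelling-time definitions above can be checked on a toy trajectory with a pure-Python sketch. The notebook stores trajectories as `(n_nodes, n_steps)` arrays, whereas this sketch assumes a list of points ordered in time, and it does not drop the first step. `settling_index`, `travel_indices` and `normalized_distance_to_target` are hypothetical helpers mirroring `calc_settling_time`, `calc_travelling_time` and the `distance_to_target` computation. On the exponential approach in the example, the normalized distance roughly halves at each step, so it falls below 0.02 right after index 5.

```python
import math
from typing import List, Sequence, Tuple

Coord = Sequence[float]


def normalized_distance_to_target(traj: Sequence[Coord]) -> List[float]:
    """Distance of each point to the final point, divided by the initial distance (1 -> 0)."""
    final = traj[-1]
    dists = [math.dist(p, final) for p in traj]
    return [d / dists[0] for d in dists]


def settling_index(dist_vals: Sequence[float], threshold: float = 0.02) -> int:
    """Last index where the normalized distance is still >= threshold (-1 if always below)."""
    above = [i for i, d in enumerate(dist_vals) if not d < threshold]
    return max(above) if above else -1


def travel_indices(traj: Sequence[Coord]) -> Tuple[int, int]:
    """Step indices where the cumulative path length first reaches 10% and 90% of its total."""
    cumulative, total = [], 0.0
    for a, b in zip(traj[:-1], traj[1:]):
        total += math.dist(a, b)
        cumulative.append(total)
    fractions = [c / total for c in cumulative]
    t10 = next(i for i, f in enumerate(fractions) if f >= 0.1)
    t90 = next(i for i, f in enumerate(fractions) if f >= 0.9)
    return t10, t90


# Toy trajectory: exponential approach from (0, 0) to about (10, 0), halving the remaining gap each step
toy_traj = [(10.0 * (1.0 - 0.5 ** t), 0.0) for t in range(21)]
print(settling_index(normalized_distance_to_target(toy_traj)))  # 5
print(travel_indices(toy_traj))  # (0, 3)
```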

Robustness tests

# Update the system_rollout to run for a bit longer
eval_system_rollout_config = deepcopy(system_rollout_config)
eval_system_rollout_config.n_secs = int(system_rollout_config.n_secs*1.2) 
eval_system_rollout_config.n_system_steps = int(eval_system_rollout_config.n_secs/eval_system_rollout_config.deltaT)

eval_system_rollout = create_system_rollout_module(eval_system_rollout_config)
batched_eval_system_rollout = vmap(eval_system_rollout, in_axes=(0, None, 0, None, 0))

# perturbation hyperparams
perturbation_min_duration = 50
perturbation_max_duration = 500
T10 = jnp.median((trajectories_statistics.T10s+1))*deltaT
T90 = jnp.median((trajectories_statistics.T90s+1))*deltaT
start = T10
end = min(max(T90, start+perturbation_min_duration), start+perturbation_max_duration)

Perturbation 1: Noise 🔊

# Create the noise perturbation generator module
noise_perturbation_generator_config = Dict()
noise_perturbation_generator_config.perturbation_type = "noise"
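# add a noise kick roughly every 5 secs (consecutive kicks are at most 5 secs apart) strictly between start and end,
# each kick lasting one simulation step (deltaT) centered on its time point
noise_times = jnp.linspace(start, end, 2 + int((end-start)/5))[1:-1]
noise_perturbation_generator_config.perturbed_intervals = [[t-deltaT/2, t+deltaT/2] for t in noise_times]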
noise_perturbation_generator_config.perturbed_node_ids = random_intervention_generator_config.controlled_node_ids
noise_perturbation_generator_config.std = 0.005

noise_perturbation_generator, noise_perturbation_fn = create_perturbation_module(noise_perturbation_generator_config)
batched_noise_perturbation_generator = vmap(noise_perturbation_generator)
# example
key, subkey = jrandom.split(key)
noise_perturbation_params, log_data = noise_perturbation_generator(subkey, default_system_outputs)
key, subkey = jrandom.split(key)
noise_default_system_outputs, log_data = system_rollout(subkey, intervention_fn, start_intervention_params_1,
                                                        noise_perturbation_fn, noise_perturbation_params)
print(jtu.tree_map(lambda node: node.shape, noise_perturbation_params))
print(jtu.tree_map(lambda node: node.shape, noise_default_system_outputs))

# example in batch mode
key, *subkeys = jrandom.split(key, num=2*batch_size+1)
noise_perturbations_params, log_data = batched_noise_perturbation_generator(jnp.array(subkeys), eval_system_outputs_library)
key, *subkeys = jrandom.split(key, num=2*batch_size+1)
noise_eval_system_outputs, log_data = batched_eval_system_rollout(jnp.array(subkeys), intervention_fn, eval_intervention_params_library, 
                                                  noise_perturbation_fn, noise_perturbations_params)
print(jtu.tree_map(lambda node: node.shape, noise_perturbations_params))
print(jtu.tree_map(lambda node: node.shape, noise_eval_system_outputs))
{'y': {0: (13,), 1: (13,), 2: (13,), 3: (13,), 4: (13,), 5: (13,), 6: (13,), 7: (13,), 8: (13,), 9: (13,), 10: (13,)}}
{'cs': (12, 10000), 'ts': (10000,), 'ws': (0, 10000), 'ys': (11, 10000)}
{'y': {0: (40, 13), 1: (40, 13), 2: (40, 13), 3: (40, 13), 4: (40, 13), 5: (40, 13), 6: (40, 13), 7: (40, 13), 8: (40, 13), 9: (40, 13), 10: (40, 13)}}
{'cs': (40, 12, 12000), 'ts': (40, 12000), 'ws': (40, 0, 12000), 'ys': (40, 11, 12000)}
Figure 19: Simulation results for default initial condition A, with noise perturbation. (left) Evolution of gene expressions (with noise perturbation) through reaction time. (right) Trajectory of (ERK, RKIPP_RP) in transcriptional space.
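The timing of the noise kicks depends only on the perturbation window and the `linspace` call above, so it can be reproduced in pure Python. `linspace` places `n = 2 + int((end - start) / 5)` points, so consecutive kicks are `(end - start) / (n - 1)` seconds apart, which is at most 5 s. The noise amplitude semantics (`std = 0.005`) are not modeled here. `perturbation_window` and `noise_intervals` are hypothetical helpers.

```python
from typing import List, Tuple


def perturbation_window(t10: float, t90: float, min_duration: float = 50.0,
                        max_duration: float = 500.0) -> Tuple[float, float]:
    """Window starting at T10 and ending at T90, with its duration clamped to [50, 500] secs."""
    start = t10
    end = min(max(t90, start + min_duration), start + max_duration)
    return start, end


def noise_intervals(start: float, end: float, delta_t: float, period: float = 5.0) -> List[List[float]]:
    """One-step intervals centered on the interior points of linspace(start, end, 2 + int((end - start) / period))."""
    n = 2 + int((end - start) / period)
    step = (end - start) / (n - 1)
    times = [start + k * step for k in range(1, n - 1)]
    return [[t - delta_t / 2, t + delta_t / 2] for t in times]


win_start, win_end = perturbation_window(20.0, 80.0)  # hypothetical T10 = 20 s, T90 = 80 s
kick_intervals = noise_intervals(win_start, win_end, delta_t=0.1)
print(len(kick_intervals))  # 12 kicks, about 4.6 s apart
```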

Perturbation 2: Push 👊

# Create the push perturbation generator module
push_perturbation_generator_config = Dict()
push_perturbation_generator_config.perturbation_type = "push"
push_perturbation_generator_config.perturbed_intervals = [[int((T90+T10)/2.)-deltaT/2, int((T90+T10)/2.)+deltaT/2]]
push_perturbation_generator_config.perturbed_node_ids = observed_node_ids
push_perturbation_generator_config.magnitude = 0.1

push_perturbation_generator, push_perturbation_fn = create_perturbation_module(push_perturbation_generator_config)
batched_push_perturbation_generator = vmap(push_perturbation_generator)
# example
key, subkey = jrandom.split(key)
push_perturbation_params, log_data = push_perturbation_generator(subkey, default_system_outputs)
key, subkey = jrandom.split(key)
push_default_system_outputs, log_data = system_rollout(subkey, intervention_fn, start_intervention_params_1,
                                                        push_perturbation_fn, push_perturbation_params)

# example in batch mode
key, *subkeys = jrandom.split(key, num=2*batch_size+1)
push_perturbations_params, log_data = batched_push_perturbation_generator(jnp.array(subkeys), eval_system_outputs_library)
key, *subkeys = jrandom.split(key, num=2*batch_size+1)
push_eval_system_outputs, log_data = batched_eval_system_rollout(jnp.array(subkeys), intervention_fn, eval_intervention_params_library, 
                                                  push_perturbation_fn, push_perturbations_params)
Figure 20: Simulation results for default initial condition A, with push perturbation. (left) Evolution of ERK and RKIPP_RP (with push perturbation) through reaction time. (right) Trajectory of (ERK, RKIPP_RP) in transcriptional space.

Perturbation 3: Wall 🚧

# Create the wall perturbation generator module
wall_perturbation_generator_config = Dict()
wall_perturbation_generator_config.perturbation_type = "wall"
wall_perturbation_generator_config.wall_type = "force_field"
wall_perturbation_generator_config.perturbed_intervals = [[0, eval_system_rollout_config.n_secs]]
wall_perturbation_generator_config.perturbed_node_ids = observed_node_ids
wall_perturbation_generator_config.n_walls = 2
wall_perturbation_generator_config.walls_intersection_window = [[0.1, 0.15], [0.85, 0.9]]  # in distance travelled from 0 (start point A) to 1.0 (end point B)
wall_perturbation_generator_config.walls_length_range = [[0.1, 0.1], [0.1, 0.1]]
wall_perturbation_generator_config.walls_sigma = [1e-2, 1e-4]

wall_perturbation_generator, wall_perturbation_fn = create_perturbation_module(wall_perturbation_generator_config)
batched_wall_perturbation_generator = vmap(wall_perturbation_generator)
# example
key, subkey = jrandom.split(key)
wall_perturbation_params, log_data = wall_perturbation_generator(subkey, default_system_outputs)
key, subkey = jrandom.split(key)
wall_default_system_outputs, log_data = system_rollout(subkey, intervention_fn, start_intervention_params_1,
                                                        wall_perturbation_fn, wall_perturbation_params)

# example in batch mode
key, *subkeys = jrandom.split(key, num=2*batch_size+1)
wall_perturbations_params, log_data = batched_wall_perturbation_generator(jnp.array(subkeys), eval_system_outputs_library)
key, *subkeys = jrandom.split(key, num=2*batch_size+1)
wall_eval_system_outputs, log_data = batched_eval_system_rollout(jnp.array(subkeys), intervention_fn, eval_intervention_params_library, 
                                                  wall_perturbation_fn, wall_perturbations_params)
Figure 21: Simulation results for default initial condition A, with wall perturbation. (left) Evolution of ERK and RKIPP_RP (with wall perturbation) through reaction time. (right) Trajectory of (ERK, RKIPP_RP) in transcriptional space.
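The comment on `walls_intersection_window` says that wall positions are given as a fraction of the distance travelled from the start point A (0) to the end point B (1). The sketch below shows one way to map such a fraction to a state in transcriptional space. It assumes that "distance travelled" means the normalized cumulative Euclidean arc length along the observed trajectory. The convention used inside `create_perturbation_module` may differ, and `point_at_fraction` is a hypothetical helper, not part of AutoDiscJax.

```python
import numpy as np

# Hypothetical helper (not part of AutoDiscJax). Assumption: "distance travelled"
# is the cumulative Euclidean arc length along the trajectory, normalized to [0, 1].
def point_at_fraction(trajectory, frac):
    """State reached after travelling `frac` of the total path length.

    trajectory: array of shape (n_nodes, n_steps), e.g. the (ERK, RKIPP_RP) rows of `ys`.
    frac: float in [0, 1]; 0 is the start point A and 1 the end point B.
    """
    seg_len = np.linalg.norm(np.diff(trajectory, axis=1), axis=0)  # length of each step
    cum_len = np.concatenate([[0.0], np.cumsum(seg_len)])          # arc length at each step
    if cum_len[-1] == 0:  # the trajectory never moves
        return trajectory[:, 0].astype(float)
    # Linearly interpolate every coordinate as a function of arc length
    return np.array([np.interp(frac * cum_len[-1], cum_len, coord) for coord in trajectory])

# Usage: locate the bounds of the two default wall windows on an L-shaped toy path
toy_path = np.array([[0.0, 1.0, 2.0, 2.0],
                     [0.0, 0.0, 0.0, 1.0]])
for low, high in [[0.1, 0.15], [0.85, 0.9]]:
    print(point_at_fraction(toy_path, low), point_at_fraction(toy_path, high))
```

Because positions are expressed relative to path length rather than time, the same window places walls at comparable stages of fast and slow trajectories.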

Evaluating the robustness of the discovered behavioral abilities

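# For each perturbation parameter, three tested values from weaker to stronger perturbation;
# the middle value is used as the default when another parameter of the same perturbation is varied
test_tasks = {
    "noise_std": [0.001, 0.005, 0.01],
    "noise_period": [10, 5, 1],
    "push_magnitude": [0.05, 0.1, 0.15],
    "push_number": [1, 2, 3],
    "wall_length": [0.05, 0.1, 0.15],
    "wall_number": [1, 2, 3],
}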
def get_perturbation_generator_config(var_name, var_val):
    """Build the perturbation generator config for one robustness test, where `var_name`
    (e.g. "noise_std") is set to `var_val` and the other parameter of that perturbation
    keeps its default value, i.e. the middle value of its list in `test_tasks`."""
    perturbation_generator_config = Dict()
    perturbation_type, perturbation_param = var_name.split("_")[:2]

    if perturbation_type == "noise":
        perturbation_generator_config = deepcopy(noise_perturbation_generator_config)

        if perturbation_param == "std":
            # Vary the noise std; noise windows are evenly spaced using the default period
            n_noises = int((end-start)//test_tasks["noise_period"][1])
            perturbation_generator_config.perturbed_intervals = [[int(t)-deltaT/2, int(t)+deltaT/2] for t in jnp.linspace(start, end, 2 + n_noises)[1:-1]]
            perturbation_generator_config.std = var_val

        elif perturbation_param == "period":
            # Vary the noise period (one noise window every `var_val` time units) with the default std
            n_noises = int((end - start) // float(var_val))
            perturbation_generator_config.perturbed_intervals = [[int(t)-deltaT/2, int(t)+deltaT/2] for t in jnp.linspace(start, end, 2 + n_noises)[1:-1]]
            perturbation_generator_config.std = test_tasks["noise_std"][1]

    elif perturbation_type == "push":
        perturbation_generator_config = deepcopy(push_perturbation_generator_config)

        if perturbation_param == "magnitude":
            # Vary the push magnitude; a single push centred between T10 and T90
            perturbation_generator_config.perturbed_intervals = [[int((T90+T10)/2.)-deltaT/2, int((T90+T10)/2.)+deltaT/2]]
            perturbation_generator_config.magnitude = var_val

        elif perturbation_param == "number":
            # Vary the number of pushes (evenly spaced between start and end) with the default magnitude
            perturbation_generator_config.perturbed_intervals = [[int(t)-deltaT/2, int(t)+deltaT/2] for t in jnp.linspace(start, end, 2 + var_val)[1:-1]]
            perturbation_generator_config.magnitude = test_tasks["push_magnitude"][1]

    elif perturbation_type == "wall":
        perturbation_generator_config = deepcopy(wall_perturbation_generator_config)

        if perturbation_param == "length":
            # Vary the wall length; a single wall with an intersection window spanning 10% to 90% of the path
            perturbation_generator_config.n_walls = 1
            perturbation_generator_config.walls_intersection_window = [[0.1, 0.9]]
            perturbation_generator_config.walls_length_range = [[var_val, var_val]]

        elif perturbation_param == "number":
            # Vary the number of walls (default length), with windows centred on evenly spaced positions along the path
            perturbation_generator_config.n_walls = var_val
            walls_windows_pos = jnp.linspace(0, 1, 2 + var_val)
            walls_spacing = (walls_windows_pos[1] - walls_windows_pos[0]) * 3 / 4
            perturbation_generator_config.walls_intersection_window = [[t - walls_spacing, t + walls_spacing] for t in walls_windows_pos[1:-1]]
            perturbation_generator_config.walls_length_range = [[test_tasks["wall_length"][1], test_tasks["wall_length"][1]]] * var_val

    return perturbation_generator_config
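The interval arithmetic in `get_perturbation_generator_config` can be checked in isolation. The sketch below reproduces it with NumPy. The toy values for `start`, `end` and the window width are illustrative and are not the tutorial's `start`, `end` and `deltaT`. Both helper names are hypothetical.

```python
import numpy as np

def evenly_spaced_intervals(start, end, n, width):
    """n windows of size `width`, centred on n points evenly spaced strictly inside (start, end)."""
    centers = np.linspace(start, end, 2 + n)[1:-1]  # drop both endpoints
    return [[int(t) - width / 2, int(t) + width / 2] for t in centers]

def wall_windows(n_walls):
    """Relative intersection windows (fraction of the path in [0, 1]) used by the wall_number tests."""
    pos = np.linspace(0, 1, 2 + n_walls)
    half_width = (pos[1] - pos[0]) * 3 / 4
    return [[float(t - half_width), float(t + half_width)] for t in pos[1:-1]]

# Usage on toy values
print(evenly_spaced_intervals(0, 40, 3, 2))
print(wall_windows(3))
```

Running it also shows that, with the 3/4 factor, the windows of neighbouring walls overlap. For example, with 3 walls the first two windows are [0.0625, 0.4375] and [0.3125, 0.6875].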
n_perturbations = 3
root_test_save_folder = "data/robustness_tests"

for test_task_var_name, test_task_var_range in test_tasks.items():
    for test_task_var_val in test_task_var_range:
        test_save_folder=f"{root_test_save_folder}/{test_task_var_name}_{test_task_var_val}"
        print(test_save_folder)

        if not os.path.exists(test_save_folder):
            perturbation_generator_config = get_perturbation_generator_config(test_task_var_name, test_task_var_val)
            perturbation_generator, perturbation_fn = create_perturbation_module(perturbation_generator_config)

            run_robustness_tests(jax_platform_name, seed, n_perturbations, test_save_folder,
                                 eval_system_outputs_library, 
                                 eval_intervention_params_library, intervention_fn,
                                 perturbation_generator, perturbation_fn,
                                 eval_system_rollout, null_rollout_statistics_encoder,
                                 out_sanity_check=False, save_modules=False, save_logs=False)
data/robustness_tests/noise_std_0.001
data/robustness_tests/noise_std_0.005
data/robustness_tests/noise_std_0.01
data/robustness_tests/noise_period_10
data/robustness_tests/noise_period_5
data/robustness_tests/noise_period_1
data/robustness_tests/push_magnitude_0.05
data/robustness_tests/push_magnitude_0.1
data/robustness_tests/push_magnitude_0.15
data/robustness_tests/push_number_1
data/robustness_tests/push_number_2
data/robustness_tests/push_number_3
data/robustness_tests/wall_length_0.05
data/robustness_tests/wall_length_0.1
data/robustness_tests/wall_length_0.15
data/robustness_tests/wall_number_1
data/robustness_tests/wall_number_2
data/robustness_tests/wall_number_3
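The same (task, value) to folder naming is used twice: when running the tests above, and again when reloading their histories for the energy landscape in Part 5. Factoring it into a small generator keeps the two in sync. `iter_test_folders` is a hypothetical helper, not part of AutoDiscJax.

```python
import os

def iter_test_folders(root, test_tasks):
    """Yield (var_name, var_val, folder) for every robustness test, in dict order."""
    for var_name, var_range in test_tasks.items():
        for var_val in var_range:
            yield var_name, var_val, os.path.join(root, f"{var_name}_{var_val}")

# Usage: list the folders of the noise tests only
for var_name, var_val, folder in iter_test_folders("data/robustness_tests", {"noise_std": [0.001, 0.005, 0.01]}):
    print(folder)
```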
Figure 22: Robustness tests results.
Figure 23: Example of a robust pathway A->B (gray, before perturbation) stressed with various kinds of perturbations (selectable in the dropdown menu). Perturbations of smaller (top), intermediate (middle) and stronger (bottom) intensity or frequency are shown, for 3 randomly generated perturbations.

Part 5: Perspectives for reusing the behavioral catalog

Energy Landscape ⛰

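Following the paper "Trajectory-based energy landscapes of gene regulatory networks", we use the trajectories stored in the behavioral catalog to estimate an energy landscape of the (ERK, RKIPP_RP) transcriptional space. We then compare the landscapes obtained from random search, from IMGEP, and from the IMGEP discoveries stressed by the robustness tests.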
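Concretely, every state visited along a trajectory is treated as a sample. A 2D histogram estimates the occupancy probability P of each region of the plane, and the energy is defined as U = -log P, so that frequently visited regions such as attractors appear as valleys. The NumPy sketch below condenses the filtering, pooling and histogram steps of the next cells into one hypothetical function, `trajectory_energy_landscape`. It assumes an input of shape (n_trajectories, 2, n_steps) and omits the cubic interpolation used for display.

```python
import numpy as np

def trajectory_energy_landscape(trajectories, bins=10, value_range=None):
    """Estimate U = -log P on a 2D grid from a batch of trajectories.

    trajectories: array (n_trajectories, 2, n_steps) of the two observed nodes.
    Returns (U, xedges, yedges), with U indexed as U[x_bin, y_bin].
    """
    # Keep only trajectories without NaNs and without (numerically) negative concentrations
    valid = ~np.isnan(trajectories).any(axis=(1, 2)) & (trajectories >= -1e-6).all(axis=(1, 2))
    # Pool every visited state of every valid trajectory into one (2, n_points) cloud
    points = np.concatenate(trajectories[valid], axis=-1)
    H, xedges, yedges = np.histogram2d(points[0], points[1], bins=bins, range=value_range)
    H[H == 0] = 1  # pseudo-count: empty bins get a finite (high) energy instead of log(0)
    U = -np.log(H / H.sum())
    return U, xedges, yedges

# Usage: two constant toy trajectories plus one invalid (NaN) trajectory that gets filtered out
toy = np.zeros((3, 2, 4))
toy[0] = 0.1
toy[1] = 0.9
toy[2, 0, 0] = np.nan
U, xedges, yedges = trajectory_energy_landscape(toy, bins=2, value_range=[[0, 1], [0, 1]])
print(U)
```

Passing the same `value_range` to every dataset, as the tutorial does with `ymin` and `ymax`, keeps the bins identical so that the landscapes can be compared directly.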

imgep_trajectories = imgep_experiment_history.system_output_library.ys[:, observed_node_ids, :]
is_valid_bool = ~(jnp.isnan(imgep_trajectories).any(-1).any(-1)) & (imgep_trajectories>=-1e-6).all(-1).all(-1)
imgep_points = jnp.concatenate(imgep_trajectories[is_valid_bool], axis=-1)

rs_trajectories = rs_experiment_history.system_output_library.ys[:, observed_node_ids, :]
is_valid_bool = ~(jnp.isnan(rs_trajectories).any(-1).any(-1)) & (rs_trajectories>=-1e-6).all(-1).all(-1)
rs_points = jnp.concatenate(rs_trajectories[is_valid_bool], axis=-1)


imgep_perturbed_points = []
for task_name in test_tasks.keys():
    for task_var in test_tasks[task_name]:
        test_experiment_history = DictTree.load(os.path.join(f"{root_test_save_folder}/{task_name}_{task_var}", "history.pickle"))
        perturbed_trajectories = test_experiment_history.system_output_library.ys[:, :, observed_node_ids, :].reshape(len(eval_system_outputs_library.ys)*n_perturbations, 2, -1)
        is_valid_bool = ~(jnp.isnan(perturbed_trajectories).any(-1).any(-1)) & (perturbed_trajectories>=-1e-6).all(-1).all(-1)
        imgep_perturbed_points.append(jnp.concatenate(perturbed_trajectories[is_valid_bool], axis=-1))
imgep_perturbed_points = jnp.concatenate(imgep_perturbed_points, axis=-1)

all_points = jnp.concatenate([imgep_points, rs_points, imgep_perturbed_points], axis=-1)
ymin, ymax = all_points.min(-1), all_points.max(-1)
del all_points 


results = {}
for k, points in zip(["(a) Random Search", "(b) IMGEP", " (c) IMGEP+perturbations"], [rs_points, imgep_points, imgep_perturbed_points]):

    H, xedges, yedges = jnp.histogram2d(
        x=points[0, :],
        y=points[1, :],
        bins=10,
        range=jnp.stack([ymin, ymax]).transpose()
    )
    H = H.transpose()

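    # Estimate the occupancy probability P = H / H.sum(); empty bins get a pseudo-count of 1 to avoid log(0)
    H = H.at[jnp.where(H == 0)].set(1)
    # Energy landscape U = -log P (frequently visited regions are low-energy valleys)
    U = -jnp.log(H / H.sum())

    # Interpolate U, defined at the bin centers, on a finer 100x100 grid for display
    bin_sizex = xedges[1] - xedges[0]
    bin_sizey = yedges[1] - yedges[0]
    x = xedges[:-1] + bin_sizex / 2
    y = yedges[:-1] + bin_sizey / 2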
    interp = RegularGridInterpolator((x,y), U, method="cubic", bounds_error=False)
    xi = jnp.linspace(ymin[0], ymax[0], 100)
    yi = jnp.linspace(ymin[1], ymax[1], 100)
    xi, yi = jnp.meshgrid(xi, yi)
    zi = interp((xi, yi)).transpose()

    # save results
    results[k] = (xi, yi, zi)
Figure 24: Trajectory-based energy landscapes constructed from the different sets of discoveries: random search (left), IMGEP search (middle), and robustness tests of the IMGEP discoveries (right).

Therapeutic Perspectives 💊

#@title [Figure 25]
fig_idx = 25

if nb_mode == "run":

    fig = make_subplots(rows=1, cols=3, subplot_titles=["(a) Most robust pathways", "(b) Proposed Intervention", "(c) After Intervention"])
    fig.update_layout(default_layout)
    fig.update_annotations(default_annotation_layout)

    # "healthy" and "disease" regions polygons
    regions_poly = {}
    for label_idx, region_name in zip([1,2], ("healthy", "disease")):
        cluster_point_ids = jnp.where(cluster_labels==label_idx)[0]
        z_points = imgep_reached_endpoints[cluster_point_ids]
        eps=0.05
        poly = unary_union([Point(point).buffer((eps)) for point in z_points])
        poly = poly.buffer(eps*5, join_style=1).buffer(-eps*5, join_style=1)
        poly = affinity.scale(poly, xfact=(analytic_bc_space_high-analytic_bc_space_low)[0], yfact=(analytic_bc_space_high-analytic_bc_space_low)[1], origin=(0,0,0))
        poly = affinity.translate(poly, xoff=analytic_bc_space_low[0], yoff=analytic_bc_space_low[1])
        regions_poly[region_name] = poly


    # Highlight "disease" and "healthy" zones
    for col_idx in [1,3]:
        for region_name, region_poly in regions_poly.items():
            color_idx = 3 if region_name == "healthy" else 2
            x, y = region_poly.exterior.coords.xy
            fig.add_trace(go.Scatter(x=np.array(x), y=np.array(y), fill="toself", fillcolor=default_colors_shade[color_idx],
                                     name=region_name, legendgroup=region_name, showlegend=(col_idx==1),
                                     line=dict(color=default_colors[color_idx]), hoverinfo="skip"), row=1, col=col_idx)


    # plot robust trajectories
    scaling_vector = np.nanmax(trajectories, (0,2)) - np.nanmin(trajectories, (0,2))

    for sample_idx in rob_sample_ids:
        trajectory, display_ts = downsample_traj(trajectories[sample_idx], scaling_vector, 1e-3)
        fig.add_trace(go.Scatter(x=trajectory[0], y=trajectory[1], showlegend=False, mode='markers', 
                                   marker=Dict(color=system_outputs.ts[display_ts], colorscale=traj_cscale, size=2)),
                     row=1, col=1)
        fig.add_annotation(x=trajectory[0,-1], y=trajectory[1,-1], text=f'#{sample_idx}', row=1, col=1)


    # Highlight target endpoint
    fig.add_trace(go.Scatter(x=[target_endpoint[0, 0]], y=[target_endpoint[0, 1]], name="target endpoint", mode="markers",
                             marker=Dict(symbol="star", size=12, color=default_colors[3])),
                  row=1, col=3)



    # plot discovered intervention
    fig.add_trace(go.Scatter(x=jnp.array(controlled_intervals).flatten(), y=jnp.repeat(best_params.clamping[7][0],2), 
                             name=f"{controlled_node} intervention", mode="lines", line_color=default_colors[7]), 
                  row=1, col=2)

    # plot trajectories after intervention
    for rel_idx, sample_idx in enumerate(disease_pathways_ids):

        # Intervention on disease endpoint to put it back to healthy state
        trajectory, display_ts = downsample_traj(best_outputs[0].ys[rel_idx, jnp.array(observed_node_ids)], scaling_vector, 1e-3)
        marker = Dict(color=system_outputs.ts[display_ts], colorscale=traj_cscale, size=2)
        if rel_idx == len(disease_pathways_ids)-1:
            marker.colorbar=Dict(title="time [secs]", thickness=10, len=0.7, y=-0.15, yanchor="bottom")
        fig.add_trace(go.Scatter(x=trajectory[0], y=trajectory[1], showlegend=False, mode='markers', marker=marker),
                     row=1, col=3)

        fig.add_annotation(x=trajectory[0,0], y=trajectory[1,0], text=f'#{sample_idx}', row=1, col=3)




    # Update Layout
    fig.update_xaxes(default_layout.xaxis, title_text=f"Reaction time (secs)", row=1, col=2)
    fig.update_yaxes(default_layout.yaxis, title_text=f"{controlled_node} [&mu;M]", row=1, col=2)
    for col_idx in [1,3]:
        fig.update_xaxes(default_layout.xaxis, title_text=f"{observed_node_names[0]} [&mu;M]", gridcolor='rgba(0, 0, 0, 0)', row=1, col=col_idx)
        fig.update_yaxes(default_layout.yaxis, title_text=f"{observed_node_names[1]} [&mu;M]", gridcolor='rgba(0, 0, 0, 0)', row=1, col=col_idx)




    # Serialize fig to json and save
    if nb_save_outputs:
        fig.write_json(f"figures/tuto1_fig_{fig_idx}.json")


elif nb_mode == "load":
    json_fig = requests.get(f"https://raw.githubusercontent.com/flowersteam/curious-exploration-of-grn-competencies/main/notebooks/figures/tuto1_fig_{fig_idx}.json").text
    fig = plotly.io.from_json(json_fig)

# Display Fig
width, height = 940, 300
t = f'(a) The most robust pathways identified (average sensitivity <{0.05}) are displayed. ' \
f'Most of them converge toward attractors in the "disease" region (orange). '  \
f'(b) Example stepwise intervention on MEKPP, found with simple random search, that we apply to states stuck in the "disease" region for 100 seconds. '  \
f'(c) The discovered intervention successfully brings all points from the "disease" region back closer to the target endpoint in the "healthy" region (green), '  \
f'and does so under various tested perturbations (not shown here).'

if nb_renderer == "html":
    html_fig = make_html_fig(fig_idx, fig, width, height, t)
    display(HTML(html_fig))

elif nb_renderer == "img":
    img_fig, img_title = make_img_fig(fig_idx, fig, width, height, t)
    display(Image(img_fig))
    display(Markdown(img_title))
Figure 25: (a) The most robust pathways identified (average sensitivity <0.05) are displayed. Most of them converge toward attractors in the "disease" region (orange). (b) Example stepwise intervention on MEKPP, found with simple random search, that we apply to states stuck in the "disease" region for 100 seconds. (c) The discovered intervention successfully brings all points from the "disease" region back closer to the target endpoint in the "healthy" region (green), and does so under various tested perturbations (not shown here).

Our behavioral catalog makes it possible to identify robust pathways toward desired or undesired states, in this case cancer pathways (to check: Loss of Raf Kinase Inhibitor Protein Promotes Cell Proliferation and Migration of Human Hepatoma Cells), which were not detected by simple random search.
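Figure 25 (a) keeps the pathways whose average sensitivity is below 0.05. The sensitivity measure itself comes from the robustness tests and is not reproduced here. Assuming it is available as an array of shape (n_pathways, n_tests), the selection step reduces to a threshold on the mean, as in this hypothetical helper:

```python
import numpy as np

def select_robust_pathways(sensitivities, threshold=0.05):
    """Indices of pathways whose sensitivity, averaged over all perturbation tests, is below `threshold`.

    sensitivities: array (n_pathways, n_tests); NaN entries (e.g. failed simulations) are ignored.
    """
    mean_sensitivity = np.nanmean(sensitivities, axis=1)
    return np.where(mean_sensitivity < threshold)[0]

# Usage on toy values: pathways 0 and 2 are kept, pathway 1 is too sensitive
toy_sensitivities = np.array([[0.01, 0.02, 0.03],
                              [0.20, 0.10, 0.00],
                              [0.04, np.nan, 0.05]])
print(select_robust_pathways(toy_sensitivities))
```

Averaging over all tests favours pathways that are robust across perturbation types. A stricter variant could threshold the maximum instead of the mean.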